TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/ldm.py ADDED
@@ -0,0 +1,2156 @@
1
+ """
2
+ **Latent Diffusion Models (LDM)**
3
+
4
+ This module provides a framework for training and sampling Latent Diffusion Models, as
5
+ described in Rombach et al. (2022, "High-Resolution Image Synthesis with Latent Diffusion
6
+ Models"). It supports diffusion in the latent space using a variational autoencoder
7
+ (compressor model), includes utilities for training the autoencoder, noise predictor, and
8
+ conditional model, and provides metrics for evaluating generated images. The framework is
9
+ compatible with DDPM, DDIM, and SDE diffusion models, supporting both unconditional and
10
+ conditional generation with text prompts.
11
+
12
+ **Components**
13
+
14
+ - **AutoencoderLDM**: Variational autoencoder for compressing images to latent space and
15
+ decoding back to image space.
16
+ - **TrainAE**: Trainer for AutoencoderLDM, optimizing reconstruction and regularization
17
+ losses with evaluation metrics.
18
+ - **TrainLDM**: Training loop with mixed precision, warmup, and scheduling for the noise
19
+ predictor and conditional model (e.g., TextEncoder with projection layers) in latent
20
+ space, with image-domain evaluation metrics using a reverse diffusion model.
21
+ - **SampleLDM**: Image generation from trained models, decoding from latent to image space.
22
+
23
+
24
+ **Notes**
25
+
26
+
27
+ - The `variance_scheduler` parameter expects an external hyperparameter module (e.g.,
28
+ VarianceSchedulerDDPM, VarianceSchedulerSDE) as an nn.Module for noise schedule management.
29
+ - AutoencoderLDM serves as the `compressor_model` in TrainLDM and SampleLDM, providing
30
+ `encode` and `decode` methods for latent space conversion. It supports KL-divergence or
31
+ vector quantization (VQ) regularization, using internal components (DownBlock, UpBlock,
32
+ Conv3, DownSampling, UpSampling, Attention, VectorQuantizer).
33
+ - TrainAE trains AutoencoderLDM, optimizing reconstruction (MSE), regularization (KL or
34
+ VQ), and optional perceptual (LPIPS) losses, with metrics (MSE, PSNR, SSIM, FID, LPIPS)
35
+ computed via the Metrics class, KL warmup, early stopping, and learning rate scheduling.
36
+ - TrainLDM trains the noise predictor and conditional model, optimizing MSE between
37
+ predicted and ground truth noise, with optional validation metrics (MSE, PSNR, SSIM, FID,
38
+ LPIPS) on generated images decoded from latents sampled using a reverse diffusion model
39
+ (e.g., ReverseDDPM).
40
+ - SampleLDM supports multiple diffusion models ("ddpm", "ddim", "sde") via the `model`
41
+ parameter, requiring compatible `reverse_diffusion` modules (e.g., ReverseDDPM,
42
+ ReverseDDIM, ReverseSDE).
43
+
44
+
45
+ **References**
46
+
47
+ - Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models."
48
+ Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
49
+
50
+
51
+ - Esser, Patrick, Robin Rombach, and Bjorn Ommer. "Taming transformers for high-resolution image synthesis."
52
+ Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
53
+
54
+ ---------------------------------------------------------------------------------
55
+ """
56
+
57
+
58
+ import torch
59
+ import torch.nn as nn
60
+ import torch.nn.functional as F
61
+ from typing import Optional, Tuple, Any, Callable, List, Union, Self
62
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
63
+ import torch.distributed as dist
64
+ from torch.nn.parallel import DistributedDataParallel as DDP
65
+ from torch.distributed import init_process_group, destroy_process_group
66
+ from torch.optim.lr_scheduler import LambdaLR
67
+ from transformers import BertTokenizer
68
+ import warnings
69
+ from tqdm import tqdm
70
+ from torchvision.utils import save_image
71
+ import os
72
+
73
+
74
+
75
+ ###==================================================================================================================###
76
+
77
+ class TrainLDM(nn.Module):
78
+ """Trainer for the noise predictor in Latent Diffusion Models.
79
+
80
+ Optimizes the noise predictor and conditional model (e.g., TextEncoder)
81
+ to predict noise in the latent space of AutoencoderLDM, using a diffusion model (e.g., DDPM, DDIM, SDE).
82
+ Supports mixed precision, conditional generation with text prompts, and evaluation metrics
83
+ (MSE, PSNR, SSIM, FID, LPIPS) for generated images during validation, using a specified reverse
84
+ diffusion model.
85
+
86
+ Parameters
87
+ ----------
88
+ diffusion_model : str
89
+ Diffusion model type ("ddpm", "ddim", "sde").
90
+ forward_diffusion : ForwardDDPM, ForwardDDIM, or ForwardSDE
91
+ Forward diffusion model defining the noise schedule.
92
+ reverse_diffusion : ReverseDDPM, ReverseDDIM, or ReverseSDE
93
+ Reverse diffusion model used to generate samples for the validation metrics (required).
94
+ noise_predictor : torch.nn.Module
95
+ Model to predict noise in the latent space (e.g., NoisePredictor).
96
+ compressor_model : torch.nn.Module
97
+ Variational autoencoder for encoding/decoding latents.
98
+ optimizer : torch.optim.Optimizer
99
+ Optimizer for the noise predictor and conditional model (e.g., Adam).
100
+ objective : Callable
101
+ Loss function for noise prediction (e.g., MSELoss).
102
+ data_loader : torch.utils.data.DataLoader
103
+ DataLoader for training data.
104
+ val_loader : torch.utils.data.DataLoader, optional
105
+ DataLoader for validation data (default: None).
106
+ conditional_model : TextEncoder, optional
107
+ Text encoder with projection layers for conditional generation (default: None).
108
+
109
+ metrics_ : object, optional
110
+ Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
111
+ max_epochs : int, optional
112
+ Maximum number of training epochs (default: 1000).
113
+ device : str, optional
114
+ Device for computation (e.g., 'cuda', 'cpu') (default: None).
115
+ store_path : str, optional
116
+ Directory in which checkpoints are saved as 'ldm_epoch_{epoch}{suffix}.pth' (default: None, uses 'ldm_model').
117
+ patience : int, optional
118
+ Number of epochs to wait for early stopping if validation loss doesn’t improve
119
+ (default: 100).
120
+ warmup_epochs : int, optional
121
+ Length of the linear learning-rate warmup, counted in optimizer updates (default: 100).
122
+ bert_tokenizer : BertTokenizer, optional
123
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
124
+ max_token_length : int, optional
125
+ Maximum sequence length for tokenized text (default: 77).
126
+ val_frequency : int, optional
127
+ Frequency (in epochs) for validation and metric computation (default: 10).
128
+ image_output_range : tuple, optional
129
+ Range for clamping generated images (default: (-1, 1)).
130
+ normalize_output : bool, optional
131
+ Whether to normalize generated images to [0, 1] for metrics (default: True).
132
+ use_ddp : bool, optional
133
+ Whether to use Distributed Data Parallel training (default: False).
134
+ grad_accumulation_steps : int, optional
135
+ Number of gradient accumulation steps before optimizer update (default: 1).
136
+ log_frequency : int, optional
137
+ Frequency (in epochs) for printing the training loss (default: 1).
138
+ use_compilation : bool, optional
139
+ Whether the models are compiled with torch.compile before training (default: False).
140
+ """
141
+
142
def __init__(
    self,
    diffusion_model: str,
    forward_diffusion: torch.nn.Module,
    reverse_diffusion: torch.nn.Module,
    noise_predictor: torch.nn.Module,
    compressor_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    data_loader: torch.utils.data.DataLoader,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    conditional_model: Optional[torch.nn.Module] = None,
    metrics_: Optional[Any] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    store_path: Optional[str] = None,
    patience: int = 100,
    warmup_epochs: int = 100,
    bert_tokenizer: Optional[BertTokenizer] = None,
    max_token_length: int = 77,
    val_frequency: int = 10,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    normalize_output: bool = True,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False
) -> None:
    """Set up devices, move the models onto them, and build schedulers and tokenizer.

    See the class docstring for the meaning of every parameter.

    Raises
    ------
    ValueError
        If ``diffusion_model`` is not "ddpm", "ddim" or "sde"; if
        ``grad_accumulation_steps``, ``val_frequency`` or ``log_frequency`` is smaller
        than 1; if no ``bert_tokenizer`` is given and the default tokenizer cannot be
        loaded; or (from the DDP setup) if a required environment variable is missing.
    RuntimeError
        If ``use_ddp`` is True but CUDA is not available.
    """
    super().__init__()
    if diffusion_model not in ["ddpm", "ddim", "sde"]:
        raise ValueError(f"Unknown model: {diffusion_model}. Supported: ddpm, ddim, sde")
    # These values are used as divisors / moduli in the training loop; reject them here
    # instead of failing with ZeroDivisionError in the middle of training.
    for name, value in (
        ("grad_accumulation_steps", grad_accumulation_steps),
        ("val_frequency", val_frequency),
        ("log_frequency", log_frequency),
    ):
        if value < 1:
            raise ValueError(f"{name} must be >= 1, got {value}")
    self.diffusion_model = diffusion_model

    # DDP settings and the device must be known before any model is moved
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # under DDP this replaces self.device with cuda:<LOCAL_RANK>
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to the selected device
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    self.noise_predictor = noise_predictor.to(self.device)
    self.compressor_model = compressor_model.to(self.device)
    # `is not None` instead of truthiness: container modules such as an empty
    # nn.Sequential define __len__ and would otherwise be treated as "no model"
    self.conditional_model = (
        conditional_model.to(self.device) if conditional_model is not None else None
    )

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ldm_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning-rate scheduling: plateau-based decay plus a linear warmup
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # tokenizer for text prompts
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        self.tokenizer = bert_tokenizer
230
+
231
def _setup_ddp(self) -> None:
    """Configure this process for Distributed Data Parallel training.

    Checks that the RANK, LOCAL_RANK and WORLD_SIZE environment variables are present
    and that CUDA is available. It initialises an NCCL process group unless one already
    exists, then records the rank information and binds the process to
    ``cuda:<LOCAL_RANK>``. Global rank 0 is marked as the master process.

    Raises
    ------
    ValueError
        If one of the required environment variables is not set (checked in the order
        RANK, LOCAL_RANK, WORLD_SIZE).
    RuntimeError
        If CUDA is not available.
    """
    env = os.environ
    for var_name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
        if var_name not in env:
            raise ValueError(f"DDP enabled but {var_name} environment variable not set")

    if not torch.cuda.is_available():
        raise RuntimeError("DDP requires CUDA but CUDA is not available")

    # the caller may already have created the process group
    if not torch.distributed.is_initialized():
        init_process_group(backend="nccl")

    self.ddp_rank = int(env["RANK"])              # global rank across all nodes
    self.ddp_local_rank = int(env["LOCAL_RANK"])  # rank on the current node
    self.ddp_world_size = int(env["WORLD_SIZE"])  # total number of processes

    # pin this process to its local GPU
    self.device = torch.device(f"cuda:{self.ddp_local_rank}")
    torch.cuda.set_device(self.device)

    # rank 0 is responsible for logging and checkpointing
    self.master_process = self.ddp_rank == 0
    if not self.master_process:
        return
    print(f"DDP initialized with world_size={self.ddp_world_size}")
267
+
268
def _setup_single_gpu(self) -> None:
    """Configure a non-distributed run on a single GPU or the CPU.

    Sets the same rank attributes as the DDP setup, so the rest of the trainer treats
    both modes uniformly. The process is rank 0 of a world of size 1 and acts as the
    master process.
    """
    self.ddp_rank, self.ddp_local_rank, self.ddp_world_size = 0, 0, 1
    self.master_process = True
274
+
275
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Load a training checkpoint to resume training.

    Restores the following state:
    - the noise predictor;
    - the conditional model, when the trainer has one and the checkpoint stores one;
    - the variance scheduler of the forward and reverse diffusion processes;
    - the optimizer state.

    Weights are loaded into the bare model underneath any DistributedDataParallel or
    ``torch.compile`` wrapper. Wrapper key prefixes (``module.`` / ``_orig_mod.``)
    shared by every stored key are stripped when the bare model does not expect them.
    Loading therefore works both before the models are DDP-wrapped (wrapping happens
    when training starts) and after it, and for checkpoints saved from wrapped or plain
    models.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        Epoch stored in the checkpoint, or -1 if absent.
    loss : float
        Loss stored in the checkpoint, or ``inf`` if absent.

    Raises
    ------
    FileNotFoundError
        If no file exists at ``checkpoint_path``.
    KeyError
        If the checkpoint lacks the noise-predictor, variance-scheduler or optimizer entry.
    """
    def _load_into(model: nn.Module, state_dict: dict) -> None:
        # the bare module shares its parameters with any wrapper around it
        target = model.module if isinstance(model, DDP) else model
        target = getattr(target, "_orig_mod", target)
        expected = set(target.state_dict())
        # Wrappers prefix *every* key, so only strip a prefix shared by all keys, and stop
        # as soon as the keys match. Plain str.replace would also mangle inner names
        # such as "submodule.weight".
        while set(state_dict) != expected:
            for prefix in ("module.", "_orig_mod."):
                if state_dict and all(key.startswith(prefix) for key in state_dict):
                    state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
                    break
            else:
                break  # nothing left to strip; load_state_dict reports the mismatch
        target.load_state_dict(state_dict)

    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from None

    # noise predictor (mandatory)
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    _load_into(self.noise_predictor, checkpoint['model_state_dict_noise_predictor'])

    # conditional model (optional on both sides)
    if self.conditional_model is not None:
        cond_state_dict = checkpoint.get('model_state_dict_conditional')
        if cond_state_dict is not None:
            _load_into(self.conditional_model, cond_state_dict)
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # variance scheduler: restore forward and reverse processes independently, so a
    # non-Module scheduler on one side never overwrites an nn.Module on the other
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    scheduler_state = checkpoint['variance_scheduler_model']
    try:
        for diffusion in (self.forward_diffusion, self.reverse_diffusion):
            if isinstance(diffusion.variance_scheduler, nn.Module):
                diffusion.variance_scheduler.load_state_dict(scheduler_state)
            else:
                diffusion.variance_scheduler = scheduler_state
    except Exception as e:
        warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # optimizer state
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")

    return epoch, loss
359
+
360
+
361
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear learning-rate warmup scheduler.

    The learning-rate multiplier is ``step / warmup_epochs`` while
    ``step < warmup_epochs``, so it starts at 0. From then on it is 1.0, which means
    the optimizer's initial learning rate is reached after ``warmup_epochs`` scheduler
    steps and kept afterwards.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is scheduled.
    warmup_epochs : int
        Number of scheduler steps in the warmup phase.

    Returns
    -------
    torch.optim.lr_scheduler.LambdaLR
        The warmup scheduler.
    """
    def linear_ramp(step: int) -> float:
        return step / warmup_epochs if step < warmup_epochs else 1.0

    return LambdaLR(optimizer, lr_lambda=linear_ramp)
387
+
388
def _wrap_models_for_ddp(self) -> None:
    """Wrap the trainable models in DistributedDataParallel when DDP is enabled.

    The noise predictor and, if present, the conditional model are wrapped with
    ``device_ids=[self.ddp_local_rank]`` and ``find_unused_parameters=True``. Models that
    are already DDP instances are left as they are, so running the trainer again does not
    nest DDP wrappers. Does nothing when ``self.use_ddp`` is False.
    """
    if not self.use_ddp:
        return

    def _wrap(model: torch.nn.Module) -> torch.nn.Module:
        if isinstance(model, DDP):
            return model  # already wrapped (e.g. a previous training run)
        # NOTE(review): find_unused_parameters=True adds per-iteration overhead;
        # presumably kept because some parameters may receive no gradient — confirm
        return DDP(model, device_ids=[self.ddp_local_rank], find_unused_parameters=True)

    self.noise_predictor = _wrap(self.noise_predictor)
    if self.conditional_model is not None:
        self.conditional_model = _wrap(self.conditional_model)
405
+
406
def forward(self) -> Tuple[List, float]:
    """Train the noise predictor and conditional model with mixed precision.

    Each step does the following:
    - encodes a batch with the frozen compressor model;
    - noises the latents with the forward diffusion process at uniformly sampled
      timesteps;
    - regresses the added noise with ``self.objective``.

    Gradients are accumulated over ``grad_accumulation_steps`` batches; a trailing
    partial window is applied at the end of each epoch. They are then unscaled, clipped
    to a max norm of 1.0 and applied through a ``GradScaler``. The warmup scheduler is
    stepped after each update until it reaches the full learning rate. From then on
    ``ReduceLROnPlateau`` alone controls the learning rate.

    Every ``val_frequency`` epochs the model is validated (if ``val_loader`` is set).
    The master process saves a checkpoint when the monitored loss improves. Early
    stopping triggers once ``patience`` epochs have passed since the last improvement.
    Under DDP the stop decision of rank 0 is broadcast, so all ranks leave together.

    Returns
    -------
    train_losses : list of float
        Mean training loss of every completed epoch (averaged across ranks under DDP).
    best_val_loss : float
        Best monitored loss seen on validation-frequency epochs. This is the validation
        loss, or the epoch's training loss when no ``val_loader`` is given. It is
        tracked on the master process only; other ranks return ``inf``.
    """
    # set models to training mode; the pre-trained compressor stays in eval mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    self.compressor_model.eval()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (if supported)
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
            self.compressor_model = torch.compile(self.compressor_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # autocast needs the device *type*. self.device is a torch.device, which never
    # compares equal to the plain string 'cuda' (nor would 'cuda:N'), so the old check
    # always selected CPU autocast and CUDA training ran without mixed precision.
    # Loss scaling is only needed for CUDA float16.
    device_type = self.device.type
    scaler = torch.GradScaler(enabled=device_type == "cuda")
    accum_steps = self.grad_accumulation_steps
    warmup = self.warmup_lr_scheduler
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    def _apply_update() -> None:
        """Clip the accumulated gradients, step the optimizer and advance the warmup."""
        scaler.unscale_(self.optimizer)  # clip the true (unscaled) gradient norms
        torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
        if self.conditional_model is not None:
            torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad()
        # LambdaLR recomputes the lr from the *base* lr on every step, so stepping it past
        # the warmup would undo every ReduceLROnPlateau reduction; stop at multiplier 1.0
        if warmup.lr_lambdas[0](warmup.last_epoch) < 1.0:
            warmup.step()

    for epoch in range(self.max_epochs):
        # give the distributed sampler a different shuffle each epoch
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        pending_update = False  # gradients accumulated but not yet applied

        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)
            with torch.no_grad():
                x, _ = self.compressor_model.encode(x)  # diffusion runs in latent space

            y_encoded = self._process_conditional_input(y) if self.conditional_model is not None else None

            with torch.autocast(device_type=device_type):
                noise = torch.randn_like(x)
                t = torch.randint(
                    0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],), device=self.device
                )
                noisy_x = self.forward_diffusion(x, noise, t)
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
                # average the loss over the accumulation window
                loss = self.objective(predicted_noise, noise) / accum_steps

            scaler.scale(loss).backward()
            pending_update = True
            if (step + 1) % accum_steps == 0:
                _apply_update()
                pending_update = False

            train_losses_epoch.append(loss.item() * accum_steps)  # record the unscaled loss

        # apply a trailing partial accumulation window instead of leaking its gradients
        # into the first update of the next epoch
        if pending_update:
            _apply_update()

        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()
        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        is_val_epoch = (epoch + 1) % self.val_frequency == 0
        if self.val_loader is not None and is_val_epoch:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()
            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()
            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # checkpointing and patience tracking happen on the master process only
        stop = False
        if self.master_process:
            if current_best < best_val_loss and is_val_epoch:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop = True

        # share rank 0's decision so every rank leaves the loop together instead of the
        # others blocking forever in the next collective
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop = bool(stop_flag.item())
        if stop:
            break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
569
+
570
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
    """Turn a batch of conditioning inputs into conditional-model embeddings.

    Each element of ``y`` is converted to ``str``; a tensor is first moved to the CPU
    and turned into a Python list. The strings are tokenized with ``self.tokenizer``
    (padded/truncated to ``self.max_token_length``) on ``self.device``. The token ids
    and attention mask are then passed to ``self.conditional_model``.

    Parameters
    ----------
    y : torch.Tensor or list
        Conditioning inputs of the batch (text prompts, or values that get stringified).

    Returns
    -------
    torch.Tensor
        Output of ``self.conditional_model`` for the batch.
    """
    raw_items = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
    prompts = list(map(str, raw_items))

    batch = self.tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt",
    ).to(self.device)
    return self.conditional_model(batch["input_ids"], batch["attention_mask"])
602
+
603
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
    """Save a training checkpoint (only called by the master process).

    Writes ``ldm_epoch_{epoch}{suffix}.pth`` into ``self.store_path``, creating the
    directory if needed. Model weights are taken from the bare models underneath any
    DistributedDataParallel / ``torch.compile`` wrappers, so the stored keys carry no
    ``module.`` or ``_orig_mod.`` prefix and load into unwrapped models. A failure is
    reported with a printed message instead of interrupting training.

    Parameters
    ----------
    epoch : int
        Current epoch number (also used in the file name).
    loss : float
        Loss value stored with the checkpoint.
    suffix : str, optional
        Suffix appended to the file name (e.g. "_early_stop").
    """
    def _bare(model: nn.Module) -> nn.Module:
        # strip the DDP wrapper first, then the torch.compile wrapper
        if isinstance(model, DDP):
            model = model.module
        return getattr(model, "_orig_mod", model)

    try:
        noise_predictor_state = _bare(self.noise_predictor).state_dict()
        conditional_state = (
            _bare(self.conditional_model).state_dict() if self.conditional_model is not None else None
        )
        scheduler = self.forward_diffusion.variance_scheduler

        checkpoint = {
            'epoch': epoch,
            'model_state_dict_noise_predictor': noise_predictor_state,
            'model_state_dict_conditional': conditional_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'variance_scheduler_model': (
                scheduler.state_dict() if isinstance(scheduler, nn.Module) else scheduler
            ),
            'max_epochs': self.max_epochs,
        }

        os.makedirs(self.store_path, exist_ok=True)
        filepath = os.path.join(self.store_path, f"ldm_epoch_{epoch}{suffix}.pth")
        torch.save(checkpoint, filepath)

        print(f"Model saved at epoch {epoch}")

    except Exception as e:
        # best effort: a failed save must not abort a long training run
        print(f"Failed to save model: {e}")
650
+
651
+
652
+ def validate(self) -> Tuple[float, float, float, float, float, float]:
653
+ """Validates the noise predictor and computes evaluation metrics.
654
+
655
+ Computes validation loss (MSE between predicted and ground truth noise) and generates
656
+ samples using the reverse diffusion model. Evaluates image quality metrics if available.
657
+
658
+ Returns
659
+ -------
660
+ tuple
661
+ (val_loss, fid, mse, psnr, ssim, lpips_score) where metrics may be None if not computed.
662
+ """
663
+ self.noise_predictor.eval()
664
+ if self.conditional_model is not None:
665
+ self.conditional_model.eval()
666
+ if self.forward_diffusion.variance_scheduler.trainable_beta:
667
+ self.forward_diffusion.eval()
668
+ self.reverse_diffusion.eval()
669
+
670
+ val_losses = []
671
+ fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
672
+
673
+ num_steps = self.forward_diffusion.variance_scheduler.tau_num_steps if self.diffusion_model == "ddim" else self.forward_diffusion.variance_scheduler.num_steps
674
+
675
+ with torch.no_grad():
676
+ for x, y in self.val_loader:
677
+ x = x.to(self.device)
678
+ x_orig = x.clone()
679
+ x, _ = self.compressor_model.encode(x)
680
+
681
+ # process conditional input
682
+ if self.conditional_model is not None:
683
+ y_encoded = self._process_conditional_input(y)
684
+ else:
685
+ y_encoded = None
686
+
687
+ # compute validation loss
688
+ noise = torch.randn_like(x).to(self.device)
689
+ t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)
690
+
691
+ noisy_x = self.forward_diffusion(x, noise, t)
692
+ predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
693
+ loss = self.objective(predicted_noise, noise)
694
+ val_losses.append(loss.item())
695
+ # generate samples for metrics evaluation
696
+ if self.metrics_ is not None and self.reverse_diffusion is not None:
697
+ xt = torch.randn_like(x).to(self.device)
698
+
699
+ # reverse diffusion sampling
700
+ for t in reversed(range(num_steps)):
701
+ time_steps = torch.full((xt.shape[0],), t, device=self.device)#, dtype=torch.long)
702
+ prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device)#, dtype=torch.long)
703
+ predicted_noise = self.noise_predictor(xt, time_steps, y_encoded, None)
704
+
705
+ if self.diffusion_model == "sde":
706
+ noise = torch.randn_like(xt) if getattr(self.reverse_diffusion, "sde_method", None) != "ode" else None
707
+ xt = self.reverse_diffusion(xt, noise, predicted_noise, time_steps)
708
+ elif self.diffusion_model == "ddim":
709
+ xt, _ = self.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)
710
+ elif self.diffusion_model == "ddpm":
711
+ xt = self.reverse_diffusion(xt, predicted_noise, time_steps)
712
+ else:
713
+ raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")
714
+
715
+ x_hat = self.compressor_model.decode(xt)
716
+
717
+ # clamp and normalize generated samples
718
+ x_hat = torch.clamp(x_hat, min=self.image_output_range[0], max=self.image_output_range[1])
719
+ if self.normalize_output:
720
+ x_hat = (x_hat - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
721
+ x_orig = (x_orig - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
722
+
723
+ # Compute metrics
724
+ metrics_result = self.metrics_.forward(x_orig, x_hat)
725
+ fid, mse, psnr, ssim, lpips_score = metrics_result
726
+
727
+ if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
728
+ fid_scores.append(fid)
729
+ if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
730
+ mse_scores.append(mse)
731
+ psnr_scores.append(psnr)
732
+ ssim_scores.append(ssim)
733
+ if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
734
+ lpips_scores.append(lpips_score)
735
+
736
+ # compute average metrics
737
+ val_loss = torch.tensor(val_losses).mean().item()
738
+
739
+ # all-reduce validation metrics across processes for DDP
740
+ if self.use_ddp:
741
+ val_loss_tensor = torch.tensor(val_loss, device=self.device)
742
+ dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
743
+ val_loss = val_loss_tensor.item()
744
+
745
+ fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
746
+ mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
747
+ psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
748
+ ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
749
+ lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
750
+
751
+ # return to training mode
752
+ self.noise_predictor.train()
753
+ if self.conditional_model is not None:
754
+ self.conditional_model.train()
755
+ if self.forward_diffusion.variance_scheduler.trainable_beta:
756
+ self.reverse_diffusion.train()
757
+ self.forward_diffusion.train()
758
+
759
+ return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
760
+
761
+
762
+ ###==================================================================================================================###
763
+
764
+
765
+ class SampleLDM(nn.Module):
766
+ """Sampler for generating images using Latent Diffusion Models (LDM).
767
+
768
+ Generates images by iteratively denoising random noise in the latent space using a
769
+ reverse diffusion process, decoding the result back to the image space with a
770
+ pre-trained compressor, as described in Rombach et al. (2022). Supports DDPM, DDIM,
771
+ and SDE diffusion models, as well as conditional generation with text prompts.
772
+
773
+ Parameters
774
+ ----------
775
+ diffusion_model : str
776
+ Diffusion model type. Supported: "ddpm", "ddim", "sde".
777
+ reverse_diffusion : nn.Module
778
+ Reverse diffusion module (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
779
+ noise_predictor : nn.Module
780
+ Model to predict noise added during the forward diffusion process.
781
+ compressor_model : nn.Module
782
+ Pre-trained model to encode/decode between image and latent spaces (e.g., AutoencoderLDM).
783
+ image_shape : tuple
784
+ Shape of generated images as (height, width).
785
+ conditional_model : nn.Module, optional
786
+ Model for conditional generation (e.g., TextEncoder), default None.
787
+ bert_tokenizer : str or BertTokenizer, optional
788
+ Tokenizer for processing text prompts, default "bert-base-uncased".
789
+ batch_size : int, optional
790
+ Number of images to generate per batch (default: 1).
791
+ in_channels : int, optional
792
+ Number of input channels for latent representations (default: 3).
793
+ device : torch.device, optional
794
+ Device for computation (default: CUDA if available, else CPU).
795
+ max_token_length : int, optional
796
+ Maximum length for tokenized prompts (default: 77).
797
+ image_output_range : tuple, optional
798
+ Range for clamping generated images (min, max), default (-1, 1).
799
+ """
800
+ def __init__(
801
+ self,
802
+ diffusion_model: str,
803
+ reverse_diffusion: torch.nn.Module,
804
+ noise_predictor: torch.nn.Module,
805
+ compressor_model: torch.nn.Module,
806
+ image_shape: Tuple[float, float],
807
+ conditional_model: Optional[torch.nn.Module] = None,
808
+ bert_tokenizer: str = "bert-base-uncased",
809
+ batch_size: int = 1,
810
+ in_channels: int = 3,
811
+ device: Optional[Union[str, torch.device]] = None,
812
+ max_token_length: int = 77,
813
+ image_output_range: Tuple[float, float] = (-1.0, 1.0)
814
+ ) -> None:
815
+ super().__init__()
816
+ if device is None:
817
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
818
+ elif isinstance(device, str):
819
+ self.device = torch.device(device)
820
+ else:
821
+ self.device = device
822
+ self.diffusion_model = diffusion_model
823
+ self.noise_predictor = noise_predictor.to(self.device)
824
+ self.reverse = reverse_diffusion.to(self.device)
825
+ self.compressor = compressor_model.to(self.device)
826
+ self.conditional_model = conditional_model.to(self.device) if conditional_model else None
827
+ self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer)
828
+ self.in_channels = in_channels
829
+ self.image_shape = image_shape
830
+ self.batch_size = batch_size
831
+ self.max_token_length = max_token_length
832
+ self.image_output_range = image_output_range
833
+
834
+ if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(isinstance(s, int) and s > 0 for s in image_shape):
835
+ raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
836
+ if batch_size <= 0:
837
+ raise ValueError("batch_size must be positive")
838
+ if in_channels <= 0:
839
+ raise ValueError("in_channels must be positive")
840
+ if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
841
+ raise ValueError("output_range must be a tuple (min, max) with min < max")
842
+
843
+ def tokenize(self, prompts: Union[List, str]):
844
+ """Tokenizes text prompts for conditional generation.
845
+
846
+ Converts input prompts into tokenized tensors using the specified tokenizer.
847
+
848
+ Parameters
849
+ ----------
850
+ prompts : str or list
851
+ Text prompt(s) for conditional generation. Can be a single string or a list of strings.
852
+
853
+ Returns
854
+ -------
855
+ input_ids : torch.Tensor
856
+ Tokenized input IDs, shape (batch_size, max_length).
857
+ attention_mask : torch.Tensor
858
+ Attention mask, shape (batch_size, max_length).
859
+ """
860
+ if isinstance(prompts, str):
861
+ prompts = [prompts]
862
+ elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
863
+ raise TypeError("prompts must be a string or list of strings")
864
+
865
+ encoded = self.tokenizer(
866
+ prompts,
867
+ padding="max_length",
868
+ truncation=True,
869
+ max_length=self.max_token_length,
870
+ return_tensors="pt"
871
+ )
872
+ return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
873
+
874
+
875
+ def forward(
876
+ self,
877
+ conditions: Optional[Union[List, str]] = None,
878
+ normalize_output: bool = True,
879
+ save_images: bool = True,
880
+ save_path: str = "ldm_generated"
881
+ ) -> torch.Tensor:
882
+ """Generates images using the reverse diffusion process in the latent space.
883
+
884
+ Iteratively denoises random noise in the latent space using the specified reverse
885
+ diffusion model (DDPM, DDIM, SDE), then decodes the result to the image space
886
+ with the compressor model. Supports conditional generation with text prompts.
887
+
888
+ Parameters
889
+ ----------
890
+ conditions : str or list, optional
891
+ Text prompt(s) for conditional generation, default None.
892
+ normalize_output : bool, optional
893
+ If True, normalizes output images to [0, 1] (default: True).
894
+ save_images : bool, optional
895
+ If True, saves generated images to `save_path` (default: True).
896
+ save_path : str, optional
897
+ Directory to save generated images (default: "ldm_generated").
898
+
899
+ Returns
900
+ -------
901
+ generated_imgs (torch.Tensor) - Generated images, shape (batch_size, channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
902
+ """
903
+ if conditions is not None and self.conditional_model is None:
904
+ raise ValueError("Conditions provided but no conditional model specified")
905
+ if conditions is None and self.conditional_model is not None:
906
+ raise ValueError("Conditions must be provided for conditional model")
907
+
908
+ noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)
909
+
910
+ self.noise_predictor.eval()
911
+ self.compressor.eval()
912
+ self.reverse.eval()
913
+ if self.conditional_model:
914
+ self.conditional_model.eval()
915
+
916
+ with torch.no_grad():
917
+ xt = noisy_samples
918
+ xt, _ = self.compressor.encode(xt)
919
+
920
+ if self.diffusion_model == "ddim":
921
+ num_steps = self.reverse.variance_scheduler.tau_num_steps
922
+ elif self.diffusion_model == "ddpm" or self.diffusion_model == "sde":
923
+ num_steps = self.reverse.variance_scheduler.num_steps
924
+ else:
925
+ raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")
926
+
927
+ for t in reversed(range(num_steps)):
928
+ time_steps = torch.full((self.batch_size,), t, device=self.device)#, dtype=torch.long)
929
+ prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)#, dtype=torch.long)
930
+
931
+ if self.diffusion_model == "sde":
932
+ noise = torch.randn_like(xt) if getattr(self.reverse, "sde_method", None) != "ode" else None
933
+
934
+ if self.conditional_model is not None and conditions is not None:
935
+ input_ids, attention_masks = self.tokenize(conditions)
936
+ key_padding_mask = (attention_masks == 0)
937
+ y = self.conditional_model(input_ids, key_padding_mask)
938
+ predicted_noise = self.noise_predictor(xt, time_steps, y)
939
+ else:
940
+ predicted_noise = self.noise_predictor(xt, time_steps)
941
+
942
+ if self.diffusion_model == "sde":
943
+ xt = self.reverse(xt, noise, predicted_noise, time_steps)
944
+ elif self.diffusion_model == "ddim":
945
+ xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
946
+ elif self.diffusion_model == "ddpm":
947
+ xt = self.reverse(xt, predicted_noise, time_steps)
948
+ else:
949
+ raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")
950
+
951
+ x = self.compressor.decode(xt)
952
+ generated_imgs = torch.clamp(x, min=self.image_output_range[0], max=self.image_output_range[1])
953
+ if normalize_output:
954
+ generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
955
+
956
+ # save images if save_images is True
957
+ if save_images:
958
+ os.makedirs(save_path, exist_ok=True)
959
+ for i in range(generated_imgs.size(0)):
960
+ img_path = os.path.join(save_path, f"image_{i+1}.png")
961
+ save_image(generated_imgs[i], img_path)
962
+
963
+ return generated_imgs
964
+
965
+ def to(self, device: torch.device) -> Self:
966
+ """Moves the module and its components to the specified device.
967
+
968
+ Parameters
969
+ ----------
970
+ device : torch.device
971
+ Target device for computation.
972
+
973
+ Returns
974
+ -------
975
+ sample (SampleDDIM, SampleDDIM or SampleSDE) - The module moved to the specified device.
976
+ """
977
+ self.device = device
978
+ self.noise_predictor.to(device)
979
+ self.reverse.to(device)
980
+ self.compressor.to(device)
981
+ if self.conditional_model:
982
+ self.conditional_model.to(device)
983
+ return super().to(device)
984
+
985
+ ###==================================================================================================================###
986
+
987
class AutoencoderLDM(nn.Module):
    """Variational autoencoder for latent space compression in Latent Diffusion Models.

    Encodes images into a latent space and decodes them back to the image space, used as
    the `compressor_model` in LDM’s `TrainLDM` and `SampleLDM`. Supports KL-divergence
    or vector quantization (VQ) regularization for the latent representation.

    Parameters
    ----------
    in_channels : int
        Number of input channels (e.g., 3 for RGB images).
    down_channels : list
        List of channel sizes for encoder downsampling blocks (e.g., [32, 64, 128, 256]).
        Must contain at least two entries; ``len(down_channels) - 1`` down blocks are built.
    up_channels : list
        List of channel sizes for decoder upsampling blocks (e.g., [256, 128, 64, 16]).
    out_channels : int
        Number of output channels, typically equal to `in_channels`.
    dropout_rate : float
        Dropout rate for regularization in convolutional and attention layers.
    num_heads : int
        Number of attention heads in self-attention layers.
    num_groups : int
        Number of groups for group normalization in attention layers.
    num_layers_per_block : int
        Number of convolutional layers in each downsampling and upsampling block.
    total_down_sampling_factor : int
        Total downsampling factor across the encoder (e.g., 8 for 8x reduction). Each
        block uses the integer floor of its n-th root, n being the number of down blocks.
    latent_channels : int
        Number of channels in the latent representation for diffusion models.
    num_embeddings : int
        Number of discrete embeddings in the VQ codebook (if `use_vq=True`).
    use_vq : bool, optional
        If True, uses vector quantization (VQ) regularization; otherwise, uses
        KL-divergence (default: False).
    beta : float, optional
        Weight for KL-divergence loss (if `use_vq=False`) (default: 1.0).

    Raises
    ------
    ValueError
        If `down_channels` has fewer than two entries.
    """
    def __init__(
        self,
        in_channels: int,
        down_channels: List[int],
        up_channels: List[int],
        out_channels: int,
        dropout_rate: float,
        num_heads: int,
        num_groups: int,
        num_layers_per_block: int,
        total_down_sampling_factor: int,
        latent_channels: int,
        num_embeddings: int,
        use_vq: bool = False,
        beta: float = 1.0
    ) -> None:
        super().__init__()
        assert in_channels == out_channels, "Input and output channels must match for auto-encoding"
        self.use_vq = use_vq
        self.beta = beta
        self.current_beta = beta
        num_down_blocks = len(down_channels) - 1
        if num_down_blocks < 1:
            raise ValueError("down_channels must contain at least two entries")
        # Exact integer n-th root: float arithmetic gives e.g. int(64 ** (1 / 3)) == 3.
        self.down_sampling_factor = self._integer_root(total_down_sampling_factor, num_down_blocks)

        # encoder
        self.conv1 = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=down_channels[i],
                out_channels=down_channels[i + 1],
                num_layers=num_layers_per_block,
                down_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(num_down_blocks)
        ])
        self.attention1 = Attention(down_channels[-1], num_heads, num_groups, dropout_rate)

        # latent projection
        if use_vq:
            self.vq_layer = VectorQuantizer(num_embeddings, down_channels[-1])
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)
        else:
            self.conv_mu = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.conv_logvar = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)

        # decoder
        self.conv2 = nn.Conv2d(latent_channels, up_channels[0], kernel_size=3, padding=1)
        self.attention2 = Attention(up_channels[0], num_heads, num_groups, dropout_rate)
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=up_channels[i],
                out_channels=up_channels[i + 1],
                num_layers=num_layers_per_block,
                up_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(len(up_channels) - 1)
        ])
        self.conv3 = Conv3(up_channels[-1], out_channels, dropout_rate)

    @staticmethod
    def _integer_root(value: int, n: int) -> int:
        """Return the largest integer ``r`` with ``r ** n <= value``.

        This is what ``int(value ** (1 / n))`` intends, without floating-point error
        (e.g. it returns 4 for ``(64, 3)`` where the float form returns 3).
        """
        root = int(round(value ** (1.0 / n)))
        # Correct the float estimate in either direction with exact integer comparisons.
        while root > 0 and root ** n > value:
            root -= 1
        while (root + 1) ** n <= value:
            root += 1
        return root

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Applies reparameterization trick for variational autoencoding.

        Samples from a Gaussian distribution using the mean and log-variance to enable
        differentiable training.

        Parameters
        ----------
        mu : torch.Tensor
            Mean of the latent distribution, shape (batch_size, channels, height, width).
        logvar : torch.Tensor
            Log-variance of the latent distribution, same shape as `mu`.

        Returns
        -------
        reparam (torch.Tensor) - Sampled latent representation, same shape as `mu`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes images into a latent representation.

        Processes input images through the encoder, applying convolutions, downsampling,
        self-attention, and latent projection (VQ or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        z : torch.Tensor
            Latent representation, shape (batch_size, latent_channels, height/down_sampling_factor, width/down_sampling_factor).
        reg_loss : torch.Tensor
            Scalar regularization loss (VQ loss if `use_vq=True`, KL-divergence loss if `use_vq=False`).

        **Notes**

        - The VQ loss is computed by `VectorQuantizer` if `use_vq=True`.
        - The KL-divergence loss is normalized by the number of elements of `mu`
          (batch size times latent size) and weighted by `current_beta`.
        """
        x = self.conv1(x)
        for block in self.down_blocks:
            x = block(x)
        res_x = x
        x = self.attention1(x)
        x = x + res_x
        if self.use_vq:
            z, vq_loss = self.vq_layer(x)
            z = self.quant_conv(z)
            return z, vq_loss
        mu = self.conv_mu(x)
        logvar = self.conv_logvar(x)
        z = self.reparameterize(mu, logvar)
        z = self.quant_conv(z)
        kl_unnormalized = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # mu.numel() == batch_size * (channels * height * width)
        kl_loss = kl_unnormalized / mu.numel() * self.current_beta
        return z, kl_loss

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decodes latent representations back to images.

        Processes latent representations through the decoder, applying convolutions,
        self-attention, upsampling, and final reconstruction.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation, shape (batch_size, latent_channels,
            height/down_sampling_factor, width/down_sampling_factor).

        Returns
        -------
        x (torch.Tensor) - Reconstructed images, shape (batch_size, out_channels, height, width).
        """
        x = self.conv2(z)
        res_x = x
        x = self.attention2(x)
        x = x + res_x
        for block in self.up_blocks:
            x = block(x)
        x = self.conv3(x)
        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes images to latent space and decodes them, computing reconstruction and regularization losses.

        Performs a full autoencoding pass, encoding images to the latent space, decoding
        them back, and calculating MSE reconstruction loss and regularization loss (VQ
        or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x_hat : torch.Tensor
            Reconstructed images, shape (batch_size, out_channels, height, width).
        total_loss : torch.Tensor
            Scalar sum of reconstruction (MSE) and regularization losses, still attached
            to the autograd graph so ``total_loss.backward()`` trains both terms.
        reg_loss : torch.Tensor
            Scalar regularization loss (VQ or KL-divergence).
        z : torch.Tensor
            Latent representation, shape (batch_size, latent_channels, height/down_sampling_factor, width/down_sampling_factor).

        **Notes**

        - The reconstruction loss is computed as the mean squared error between `x_hat` and `x`.
        - The regularization loss depends on `use_vq` (VQ loss or KL-divergence).
        """
        z, reg_loss = self.encode(x)
        x_hat = self.decode(z)
        recon_loss = F.mse_loss(x_hat, x)
        # Do not call .item() here: that would detach the reconstruction term and leave
        # only the regularizer in the gradient of total_loss.
        total_loss = recon_loss + reg_loss
        return x_hat, total_loss, reg_loss, z
1208
+
1209
+ ###==================================================================================================================###
1210
+
1211
class VectorQuantizer(nn.Module):
    """Vector quantization layer for discretizing latent representations.

    Quantizes input latent vectors to the nearest embedding in a learned codebook,
    used in `AutoencoderLDM` when `use_vq=True` to enable discrete latent spaces for
    Latent Diffusion Models. Computes commitment and codebook losses to train the
    encoder and the codebook embeddings.

    Parameters
    ----------
    num_embeddings : int
        Number of discrete embeddings in the codebook.
    embedding_dim : int
        Dimensionality of each embedding vector (matches input channel dimension).
    commitment_cost : float, optional
        Weight for the commitment loss, encouraging inputs to be close to quantized values (default: 0.25).


    **Notes**

    - The codebook embeddings are initialized uniformly in the range [-1/num_embeddings, 1/num_embeddings].
    - The forward pass moves the channel axis last, flattens the input so every spatial position becomes one `embedding_dim` vector, computes Euclidean distances to the codebook embeddings, and selects the nearest embedding for quantization.
    - The commitment loss (weighted by `commitment_cost`) pulls input latents towards their quantized versions, while the codebook loss moves embeddings towards the inputs.
    - A straight-through estimator is used to pass gradients from the quantized output to the input.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantizes latent representations to the nearest codebook embedding.

        Computes the closest embedding for each input vector, applies quantization,
        and calculates commitment and codebook losses for training.

        Parameters
        ----------
        z : torch.Tensor
            Input latent representation with channels on axis 1, e.g. shape
            (batch_size, embedding_dim, height, width).

        Returns
        -------
        quantized : torch.Tensor
            Quantized latent representation, same shape as `z`.
        vq_loss : torch.Tensor
            Sum of codebook loss and `commitment_cost`-weighted commitment loss.

        **Notes**

        - The input is rearranged to (batch_size * height * width, embedding_dim) for distance computation.
        - Euclidean distances are computed efficiently using vectorized operations.
        """
        z = z.contiguous()
        assert z.size(1) == self.embedding_dim, f"Expected channel dim {self.embedding_dim}, got {z.size(1)}"
        # Channels must be moved last before flattening: reshaping the channel-first
        # tensor directly would group values across spatial positions instead of
        # taking one embedding_dim vector per position.
        z_channels_last = z.movedim(1, -1)
        z_flattened = z_channels_last.reshape(-1, self.embedding_dim)
        codebook = self.embedding.weight
        distances = (torch.sum(z_flattened ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(z_flattened, codebook.t()))
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(z_channels_last.shape).movedim(-1, 1)
        # Codebook loss: moves embeddings towards the (frozen) encoder output.
        codebook_loss = torch.mean((quantized - z.detach()) ** 2)
        # Commitment loss: keeps the encoder output close to the (frozen) embeddings;
        # this is the term `commitment_cost` is meant to weight.
        commitment_loss = self.commitment_cost * torch.mean((z - quantized.detach()) ** 2)
        # Straight-through estimator: forward value is `quantized`, gradient flows to `z`.
        quantized = z + (quantized - z).detach()
        return quantized, codebook_loss + commitment_loss
1282
+
1283
+ ###==================================================================================================================###
1284
+
1285
class DownBlock(nn.Module):
    """Residual convolution stage followed by spatial down-sampling (AutoencoderLDM encoder).

    Each of the `num_layers` stages runs two `Conv3` layers and adds a 1x1-convolution
    shortcut of the stage input; after the last stage a `DownSampling` module shrinks
    the spatial resolution.

    Parameters
    ----------
    in_channels : int
        Channel count of the block input (seen only by the first stage).
    out_channels : int
        Channel width of every stage and of the block output.
    num_layers : int
        Number of residual stages (each stage holds two Conv3 modules).
    down_sampling_factor : int
        Spatial reduction applied by the final `DownSampling` module.
    dropout_rate : float
        Dropout probability passed to every Conv3.

    **Notes**

    - Shortcut projections are 1x1 convolutions so the residual sum has matching channels.
    - Down-sampling happens once, after all residual stages.
    """
    def __init__(self, in_channels: int, out_channels: int, num_layers: int, down_sampling_factor: int, dropout_rate: float) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Input width of each stage: the block input for stage 0, the block width afterwards.
        stage_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        # Submodules are created in the same order as before (conv1 list, conv2 list,
        # down-sampler, shortcuts) so registration order and the sequence of random
        # initialisations are unchanged.
        self.conv1 = nn.ModuleList(
            [Conv3(in_channels=c_in, out_channels=out_channels, dropout_rate=dropout_rate) for c_in in stage_inputs]
        )
        self.conv2 = nn.ModuleList(
            [Conv3(in_channels=out_channels, out_channels=out_channels, dropout_rate=dropout_rate) for _ in stage_inputs]
        )
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor
        )
        self.resnet = nn.ModuleList(
            [nn.Conv2d(in_channels=c_in, out_channels=out_channels, kernel_size=1) for c_in in stage_inputs]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run all residual stages, then down-sample.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
        """
        hidden = x
        for conv_a, conv_b, shortcut in zip(self.conv1, self.conv2, self.resnet):
            hidden = conv_b(conv_a(hidden)) + shortcut(hidden)
        return self.down_sampling(hidden)
1362
+
1363
+ ###==================================================================================================================###
1364
+
1365
class Conv3(nn.Module):
    """Convolutional layer with group normalization, SiLU activation, and dropout.

    Used in DownBlock and UpBlock of AutoencoderLDM for feature extraction and
    transformation in the encoder and decoder.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dropout_rate : float
        Dropout rate for regularization.
    num_groups : int, optional
        Requested number of GroupNorm groups (default: 8). GroupNorm needs the group
        count to divide `in_channels`; the layer uses ``gcd(num_groups, in_channels)``,
        which equals `num_groups` whenever it already divides `in_channels`.

    Raises
    ------
    ValueError
        If `num_groups` is not positive.

    **Notes**

    - The layer applies group normalization, SiLU activation, dropout, and a 3x3 convolution in sequence.
    - Spatial dimensions are preserved due to padding=1 in the convolution.
    """
    def __init__(self, in_channels: int, out_channels: int, dropout_rate: float, num_groups: int = 8) -> None:
        super().__init__()
        import math

        if num_groups <= 0:
            raise ValueError("num_groups must be positive")
        # A hard-coded 8 groups crashed for channel counts not divisible by 8 (e.g. 3);
        # the gcd is always a valid divisor and is unchanged for the divisible case.
        groups = math.gcd(num_groups, in_channels)
        self.group_norm = nn.GroupNorm(num_groups=groups, num_channels=in_channels)
        self.activation = nn.SiLU()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input through group normalization, activation, dropout, and convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        x = self.group_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.conv(x)
        return x
1409
+
1410
+ ###==================================================================================================================###
1411
+
1412
class DownSampling(nn.Module):
    """Downsampling module for reducing spatial dimensions in AutoencoderLDM’s encoder.

    Combines convolutional downsampling and max pooling, concatenating their outputs
    to preserve feature information during downsampling in DownBlock.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels (sum of conv and pool paths).
    down_sampling_factor : int
        Factor by which to downsample spatial dimensions.

    **Notes**

    - The conv path produces ``out_channels // 2`` channels and the pool path the remainder, so the concatenation always has exactly `out_channels` channels (also when it is odd).
    - The convolutional path uses a 3x3 kernel with stride `down_sampling_factor` and padding 1, giving ceil(H / factor) rows; the pooling path uses max pooling with the same factor and ``ceil_mode=True`` so it yields the same size and the two paths can be concatenated even when H or W is not divisible by the factor.
    """
    def __init__(self, in_channels: int, out_channels: int, down_sampling_factor: int) -> None:
        super().__init__()
        self.down_sampling_factor = down_sampling_factor
        conv_channels = out_channels // 2
        pool_channels = out_channels - conv_channels
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=conv_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        )
        self.pool = nn.Sequential(
            # ceil_mode matches the strided conv's ceil(H / factor) output size; for sizes
            # divisible by the factor it is identical to floor mode.
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor, ceil_mode=True),
            nn.Conv2d(in_channels=in_channels, out_channels=pool_channels,
                      kernel_size=1, stride=1, padding=0)
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Downsamples input by combining convolutional and pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels, ceil(height/down_sampling_factor), ceil(width/down_sampling_factor)).
        """
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
1459
+
1460
+ ###==================================================================================================================###
1461
+
1462
class Attention(nn.Module):
    """Multi-head self-attention over spatial positions, used inside AutoencoderLDM blocks.

    The feature map is group-normalized, flattened so that every spatial position becomes
    one token of dimension ``num_channels``, passed through ``nn.MultiheadAttention``
    (queries, keys and values all come from the same tokens), followed by dropout, and
    finally reshaped back to the input layout.

    Parameters
    ----------
    num_channels : int
        Number of input and output channels (embedding dimension for attention).
    num_heads : int
        Number of attention heads.
    num_groups : int
        Number of groups for group normalization.
    dropout_rate : float
        Dropout rate applied to the attention output.

    **Notes**

    - No residual connection is added here; the returned tensor is the attention output only.
    - Tokens are laid out as (batch_size, height * width, num_channels) because the
      attention layer is built with ``batch_first=True``.
    """
    def __init__(self, num_channels: int, num_heads: int, num_groups: int, dropout_rate: float) -> None:
        super().__init__()
        # Attribute names define the state_dict keys; creation order fixes the init RNG draws.
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
        self.attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention across the spatial positions of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, num_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Output tensor, same shape as input.
        """
        b, c, height, width = x.shape
        # GroupNorm accepts (N, C, L), so normalize on the flattened map directly.
        normed = self.group_norm(x.reshape(b, c, height * width))
        tokens = normed.transpose(1, 2)
        attended = self.dropout(self.attention(tokens, tokens, tokens)[0])
        return attended.transpose(1, 2).reshape(b, c, height, width)
1510
+
1511
+ ###==================================================================================================================###
1512
+
1513
class UpBlock(nn.Module):
    """Decoder stage of AutoencoderLDM: upsample, then a stack of residual Conv3 pairs.

    The input is first upsampled by ``UpSampling`` (channel count unchanged), then passed
    through ``num_layers`` residual units. Each unit applies two ``Conv3`` layers and adds
    a 1x1-convolution projection of the unit's input as the skip path.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels for convolutional layers.
    num_layers : int
        Number of convolutional layer pairs (Conv3) per block.
    up_sampling_factor : int
        Factor by which to upsample spatial dimensions.
    dropout_rate : float
        Dropout rate for Conv3 layers.

    **Notes**

    - Only the first unit sees ``in_channels``; every later unit works at ``out_channels``.
    - Each layer pair consists of two Conv3 modules.
    """
    def __init__(self, in_channels: int, out_channels: int, num_layers: int, up_sampling_factor: int, dropout_rate: float) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Input width seen by each residual unit: in_channels for the first, out_channels after.
        unit_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor
        )
        # Submodule lists are built in the same order as before (conv1, conv2, resnet)
        # so parameter initialization consumes the RNG identically.
        self.conv1 = nn.ModuleList([
            Conv3(in_channels=width, out_channels=out_channels, dropout_rate=dropout_rate)
            for width in unit_inputs
        ])
        self.conv2 = nn.ModuleList([
            Conv3(in_channels=out_channels, out_channels=out_channels, dropout_rate=dropout_rate)
            for _ in unit_inputs
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in unit_inputs
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and run it through the residual convolution units.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height * up_sampling_factor, width * up_sampling_factor).
        """
        hidden = self.up_sampling(x)
        for first, second, skip in zip(self.conv1, self.conv2, self.resnet):
            # Main path is evaluated before the skip projection, as in the original ordering.
            hidden = second(first(hidden)) + skip(hidden)
        return hidden
1591
+
1592
+ ###==================================================================================================================###
1593
+
1594
class UpSampling(nn.Module):
    """Upsampling module for increasing spatial dimensions in AutoencoderLDM's decoder.

    Combines a transposed-convolution path and a nearest-neighbor upsampling path,
    concatenating their outputs along the channel dimension.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels (sum of conv and upsample paths).
    up_sampling_factor : int
        Factor by which to upsample spatial dimensions.

    **Notes**

    - The upsampling path produces ``out_channels // 2`` channels and the transposed
      convolution path the remaining ``out_channels - out_channels // 2``, so the output
      always has exactly ``out_channels`` channels (even ``out_channels`` split evenly,
      exactly as before).
    - If the spatial dimensions of the two paths differ, the upsampling path is interpolated
      to match the convolutional path's size.
    """
    def __init__(self, in_channels: int, out_channels: int, up_sampling_factor: int) -> None:
        super().__init__()
        half_out_channels = out_channels // 2
        # The conv path absorbs the odd channel so the concatenation has out_channels channels.
        conv_out_channels = out_channels - half_out_channels
        self.up_sampling_factor = up_sampling_factor
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=conv_out_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                # (H-1)*f - 2 + 3 + (f-1) == H*f, i.e. an exact factor-f enlargement.
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=conv_out_channels,
                out_channels=conv_out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )
        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=half_out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Upsamples input by combining transposed convolution and upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels, height * up_sampling_factor, width * up_sampling_factor).

        **Notes**

        - Interpolation is applied if the spatial dimensions of the convolutional and upsampling paths differ, using nearest-neighbor mode.
        """
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
1672
+
1673
+ ###==================================================================================================================###
1674
+
1675
class TrainAE(nn.Module):
    """Trainer for the AutoencoderLDM autoencoder used in Latent Diffusion Models.

    Optimizes the autoencoder on the loss returned by the model's own forward pass
    (``x_hat, loss, reg_loss, z = model(x)``). Supports mixed precision on CUDA, KL-weight
    warmup, gradient accumulation, gradient clipping, learning-rate warmup followed by
    plateau-based reduction, early stopping, optional ``torch.compile``, DDP, and
    validation metrics (MSE, PSNR, SSIM, FID, LPIPS) via a user-supplied metrics object.

    Parameters
    ----------
    model : nn.Module
        The autoencoder to train. Its forward must return ``(x_hat, loss, reg_loss, z)``
        and it must expose ``use_vq`` and ``beta`` attributes; the trainer sets
        ``current_beta`` on it every epoch.
    optimizer : torch.optim.Optimizer
        Optimizer for training (e.g., Adam).
    data_loader : torch.utils.data.DataLoader
        DataLoader for training data; each item is unpacked as ``(x, _)``.
    val_loader : torch.utils.data.DataLoader, optional
        DataLoader for validation data (default: None).
    max_epochs : int, optional
        Maximum number of training epochs (default: 100).
    metrics_ : object, optional
        Metrics object whose ``forward(x, x_hat)`` returns ``(fid, mse, psnr, ssim, lpips)``
        and whose ``fid`` / ``metrics`` / ``lpips`` flags select what is reported
        (default: None).
    device : str or torch.device, optional
        Device for computation (e.g., 'cuda', 'cpu'). Defaults to CUDA when available.
        Overridden by ``cuda:LOCAL_RANK`` when ``use_ddp`` is True.
    store_path : str, optional
        Directory in which checkpoints are written (default: 'vlc_model').
    checkpoint : int, optional
        Stored as ``self.checkpoint`` (default: 10). NOTE(review): not currently used to
        trigger periodic saves; checkpoints are written on improvement and early stop only.
    kl_warmup_epochs : int, optional
        Number of epochs over which the KL weight ramps linearly from 0 to ``model.beta``
        (default: 10). A value <= 0 disables the ramp.
    patience : int, optional
        Epochs without improvement before early stopping; also used as the
        ``ReduceLROnPlateau`` patience (default: 10).
    val_frequency : int, optional
        Frequency (in epochs) for validation and metric computation (default: 5).
    warmup_epochs : int, optional
        Length of the linear learning-rate warmup (default: 100). The warmup scheduler is
        advanced once per optimizer update, so this counts optimizer steps.
    use_ddp : bool, optional
        Whether to use Distributed Data Parallel training (default: False).
    grad_accumulation_steps : int, optional
        Number of gradient accumulation steps before optimizer update (default: 1).
    log_frequency : int, optional
        Print the training loss every ``log_frequency`` epochs (default: 1).
    use_compilation : bool, optional
        Whether to wrap the model with ``torch.compile`` before training (default: False).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        data_loader: torch.utils.data.DataLoader,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 100,
        metrics_: Optional[Any] = None,
        device: Optional[Union[str, torch.device]] = None,
        store_path: str = "vlc_model",
        checkpoint: int = 10,
        kl_warmup_epochs: int = 10,
        patience: int = 10,
        val_frequency: int = 5,
        warmup_epochs: int = 100,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False
    ) -> None:
        super().__init__()

        # initialize DDP settings first
        self.use_ddp = use_ddp
        self.grad_accumulation_steps = grad_accumulation_steps
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # setup distributed training if enabled (may override self.device)
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.data_loader = data_loader
        self.val_loader = val_loader
        self.max_epochs = max_epochs
        self.metrics_ = metrics_
        self.store_path = store_path
        self.checkpoint = checkpoint
        self.kl_warmup_epochs = kl_warmup_epochs
        self.patience = patience
        self.use_compilation = use_compilation
        self.warmup_epochs = warmup_epochs

        # Learning rate scheduling
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
        self.val_frequency = val_frequency
        self.log_frequency = log_frequency

    def _setup_ddp(self) -> None:
        """Setup Distributed Data Parallel training configuration.

        Validates the torchrun environment variables, initializes the NCCL process group
        (if not already initialized), records rank information, and binds this process to
        ``cuda:LOCAL_RANK``.

        Raises
        ------
        ValueError
            If RANK, LOCAL_RANK or WORLD_SIZE is not set.
        RuntimeError
            If CUDA is not available.
        """
        # check if DDP environment variables are set
        if "RANK" not in os.environ:
            raise ValueError("DDP enabled but RANK environment variable not set")
        if "LOCAL_RANK" not in os.environ:
            raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
        if "WORLD_SIZE" not in os.environ:
            raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")

        # ensure CUDA is available for DDP
        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA but CUDA is not available")

        # initialize process group only if not already initialized
        if not torch.distributed.is_initialized():
            init_process_group(backend="nccl")

        # get rank information
        self.ddp_rank = int(os.environ["RANK"])  # global rank across all nodes
        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])  # local rank on current node
        self.ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

        # set device and make it current
        self.device = torch.device(f"cuda:{self.ddp_local_rank}")
        torch.cuda.set_device(self.device)

        # master process handles logging, checkpointing, etc.
        self.master_process = self.ddp_rank == 0

        if self.master_process:
            print(f"DDP initialized with world_size={self.ddp_world_size}")

    def _setup_single_gpu(self) -> None:
        """Setup single GPU or CPU training configuration (this process acts as rank 0)."""
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True

    def _unwrap_model(self) -> nn.Module:
        """Return the user's model with any DDP and ``torch.compile`` wrappers removed.

        Needed because DDP does not forward attribute access to the wrapped module, and
        setting attributes on a compiled wrapper does not reach the original module.
        """
        model = self.model
        while True:
            if isinstance(model, DDP):
                model = model.module
            elif hasattr(model, "_orig_mod"):
                # torch.compile's OptimizedModule keeps the original module here
                model = model._orig_mod
            else:
                return model

    @staticmethod
    def _strip_wrapper_prefixes(state_dict: dict) -> dict:
        """Remove DDP (``module.``) / compile (``_orig_mod.``) prefixes from a state dict.

        A prefix is only stripped when EVERY key carries it, i.e. when it was added by a
        wrapper, so parameter names that merely contain ``module.`` are left intact.
        """
        changed = True
        while changed and state_dict:
            changed = False
            for prefix in ("module.", "_orig_mod."):
                if all(key.startswith(prefix) for key in state_dict):
                    state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
                    changed = True
        return state_dict

    def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
        """Loads a training checkpoint to resume training.

        Restores the autoencoder weights and optimizer state from a checkpoint written by
        ``_save_checkpoint``. Checkpoints saved from DDP-wrapped or compiled models are
        accepted; the weights are always loaded into the unwrapped model.

        Parameters
        ----------
        checkpoint_path : str
            Path to the checkpoint file.

        Returns
        -------
        epoch : int
            The epoch at which the checkpoint was saved (-1 if absent).
        loss : float
            The loss at the checkpoint (``inf`` if absent).

        Raises
        ------
        FileNotFoundError
            If ``checkpoint_path`` does not exist.
        KeyError
            If the model or optimizer state is missing from the checkpoint.
        """
        try:
            # load checkpoint with proper device mapping
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

        if 'model_state_dict' not in checkpoint:
            raise KeyError("Checkpoint missing 'model_state_dict' key")

        # Load into the bare module regardless of wrapping, so this works both before
        # forward() wraps the model for DDP/compile and after.
        state_dict = self._strip_wrapper_prefixes(checkpoint['model_state_dict'])
        self._unwrap_model().load_state_dict(state_dict)

        if 'optimizer_state_dict' not in checkpoint:
            raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
        try:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except ValueError as e:
            warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

        epoch = checkpoint.get('epoch', -1)
        loss = checkpoint.get('loss', float('inf'))

        self.model.to(self.device)

        if self.master_process:
            print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")

        return epoch, loss

    @staticmethod
    def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
        """Creates a learning rate scheduler for warmup.

        The returned scheduler scales the optimizer's initial learning rate by
        ``step / warmup_epochs`` for ``step < warmup_epochs`` and by 1.0 afterwards, where
        ``step`` is the number of times the scheduler has been stepped.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to apply the scheduler to.
        warmup_epochs : int
            Number of scheduler steps in the warmup phase.

        Returns
        -------
        torch.optim.lr_scheduler.LambdaLR
            Learning rate scheduler for warmup.
        """

        def lr_lambda(epoch):
            if epoch < warmup_epochs:
                return epoch / warmup_epochs
            return 1.0

        return LambdaLR(optimizer, lr_lambda)

    def _wrap_models_for_ddp(self) -> None:
        """Wrap the model with DistributedDataParallel for multi-GPU training (once)."""
        if self.use_ddp and not isinstance(self.model, DDP):
            self.model = DDP(
                self.model,
                device_ids=[self.ddp_local_rank],
                find_unused_parameters=True
            )

    def _optimizer_step(self, scaler: Any) -> None:
        """Unscale, clip and apply accumulated gradients, then advance the LR warmup."""
        scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad()
        # LambdaLR rewrites lr = base_lr * lambda(step) on every step; stepping it after the
        # warmup finished would undo every reduction made by ReduceLROnPlateau.
        if self.warmup_lr_scheduler.last_epoch < self.warmup_epochs:
            self.warmup_lr_scheduler.step()

    def forward(self) -> Tuple[List[float], float]:
        """Trains the AutoencoderLDM model with mixed precision and evaluation metrics.

        Performs training with KL warmup, gradient accumulation and clipping, and learning
        rate scheduling. Saves a checkpoint whenever the monitored loss improves on a
        validation epoch and supports early stopping (synchronized across DDP ranks).

        Returns
        -------
        train_losses : list
            List of mean training losses per epoch.
        best_val_loss : float
            Best validation loss achieved (or best training loss if no validation).
            Only tracked on the master process; other ranks return ``inf``.
        """
        # compile models for optimization (if supported)
        if self.use_compilation:
            try:
                self.model = torch.compile(self.model)
            except Exception as e:
                if self.master_process:
                    print(f"Model compilation failed: {e}. Continuing without compilation.")

        # wrap models for DDP after compilation
        self._wrap_models_for_ddp()
        # attribute reads/writes (use_vq, beta, current_beta) must hit the real module
        base_model = self._unwrap_model()

        # GradScaler is only meaningful on CUDA; disabled it is a pass-through.
        scaler = torch.GradScaler(enabled=self.device.type == "cuda")
        # self.device is a torch.device, so compare its .type (handles "cuda:N" too).
        amp_device_type = "cuda" if self.device.type == "cuda" else "cpu"
        train_losses = []
        best_val_loss = float("inf")
        wait = 0

        # main training loop
        for epoch in range(self.max_epochs):
            self.model.train()
            # set epoch for distributed sampler if using DDP
            if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
                self.data_loader.sampler.set_epoch(epoch)

            if base_model.use_vq:
                beta = 1.0  # no warmup for VQ
            elif self.kl_warmup_epochs > 0:
                beta = min(1.0, epoch / self.kl_warmup_epochs) * base_model.beta
            else:
                beta = base_model.beta
            base_model.current_beta = beta

            train_losses_epoch = []
            pending_grads = False

            # training step loop with gradient accumulation
            for step, (x, _) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
                x = x.to(self.device)

                # forward pass with mixed precision
                with torch.autocast(device_type=amp_device_type):
                    _, loss, _, _ = self.model(x)
                    # scale for gradient accumulation
                    loss = loss / self.grad_accumulation_steps

                scaler.scale(loss).backward()
                pending_grads = True

                if (step + 1) % self.grad_accumulation_steps == 0:
                    self._optimizer_step(scaler)
                    pending_grads = False

                # record loss (unscaled)
                train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

            # apply a final partial accumulation window so gradients don't leak into the next epoch
            if pending_grads:
                self._optimizer_step(scaler)

            # compute mean training loss
            mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

            # all-reduce loss across processes for DDP
            if self.use_ddp:
                loss_tensor = torch.tensor(mean_train_loss, device=self.device)
                dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
                mean_train_loss = loss_tensor.item()

            train_losses.append(mean_train_loss)

            # print training progress (only master process)
            if self.master_process and (epoch + 1) % self.log_frequency == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

            # validation step
            if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
                val_metrics = self.validate()
                val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

                if self.master_process:
                    print(f" | Val Loss: {val_loss:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                        print(f" | FID: {fid:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                        print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                        print(f" | LPIPS: {lpips_score:.4f}", end="")
                    print()

                current_best = val_loss
                self.scheduler.step(val_loss)
            else:
                if self.master_process:
                    print()
                current_best = mean_train_loss
                self.scheduler.step(mean_train_loss)

            # save checkpoint and early stopping decision (only master process)
            stop = False
            if self.master_process:
                if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                    best_val_loss = current_best
                    wait = 0
                    self._save_checkpoint(epoch + 1, best_val_loss)
                else:
                    wait += 1
                    if wait >= self.patience:
                        print("Early stopping triggered")
                        self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                        stop = True

            # Rank 0 decides; all ranks must leave the loop together, otherwise the others
            # would block forever in the next collective (all_reduce / DDP backward).
            if self.use_ddp:
                stop_flag = torch.tensor(int(stop), device=self.device)
                dist.broadcast(stop_flag, src=0)
                stop = bool(stop_flag.item())
            if stop:
                break

        # clean up DDP
        if self.use_ddp:
            destroy_process_group()

        return train_losses, best_val_loss

    def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
        """Save model checkpoint (only called by master process).

        Writes ``{store_path}/ldm_epoch_{epoch}{suffix}.pth`` containing the unwrapped
        model state (no DDP/compile key prefixes), optimizer state, epoch, loss and
        ``max_epochs``. Failures are reported with a print rather than raised.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        loss : float
            Current loss value.
        suffix : str, optional
            Suffix to add to checkpoint filename.
        """
        try:
            model_state = self._unwrap_model().state_dict()

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model_state,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'max_epochs': self.max_epochs,
            }

            filename = f"ldm_epoch_{epoch}{suffix}.pth"
            filepath = os.path.join(self.store_path, filename)
            os.makedirs(self.store_path, exist_ok=True)
            torch.save(checkpoint, filepath)
            print(f"Model saved at epoch {epoch}")
        except Exception as e:
            print(f"Failed to save model: {e}")

    def validate(self) -> Tuple[float, float, float, float, float, float]:
        """Validates the AutoencoderLDM model and computes evaluation Metrics.

        Computes the mean validation loss (averaged across ranks under DDP) and, when a
        metrics object is provided, the per-batch means of the enabled metrics. The model
        is put back into training mode afterwards.

        Returns
        -------
        val_loss : float
            Mean validation loss.
        fid : float, or `float('inf')` if not computed
            Mean FID score.
        mse : float, or None if not computed
            Mean MSE
        psnr : float, or None if not computed
            Mean PSNR
        ssim : float, or None if not computed
            Mean SSIM
        lpips_score : float, or None if not computed
            Mean LPIPS score
        """
        self.model.eval()

        val_losses = []
        fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

        with torch.no_grad():
            for x, _ in self.val_loader:
                x = x.to(self.device)
                x_hat, loss, _, _ = self.model(x)
                val_losses.append(loss.item())

                # compute metrics
                if self.metrics_ is not None:
                    fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x, x_hat)

                    if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                        fid_scores.append(fid)
                    if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                        mse_scores.append(mse)
                        psnr_scores.append(psnr)
                        ssim_scores.append(ssim)
                    if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                        lpips_scores.append(lpips_score)

        # compute average metrics
        val_loss = torch.tensor(val_losses).mean().item()

        # all-reduce validation loss across processes for DDP
        if self.use_ddp:
            val_loss_tensor = torch.tensor(val_loss, device=self.device)
            dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
            val_loss = val_loss_tensor.item()

        fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
        mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
        psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
        ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
        lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

        self.model.train()

        return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg