TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/ldm.py
ADDED
|
@@ -0,0 +1,2156 @@
|
|
|
1
|
+
"""
|
|
2
|
+
**Latent Diffusion Models (LDM)**
|
|
3
|
+
|
|
4
|
+
This module provides a framework for training and sampling Latent Diffusion Models, as
|
|
5
|
+
described in Rombach et al. (2022, "High-Resolution Image Synthesis with Latent Diffusion
|
|
6
|
+
Models"). It supports diffusion in the latent space using a variational autoencoder
|
|
7
|
+
(compressor model), includes utilities for training the autoencoder, noise predictor, and
|
|
8
|
+
conditional model, and provides metrics for evaluating generated images. The framework is
|
|
9
|
+
compatible with DDPM, DDIM, and SDE diffusion models, supporting both unconditional and
|
|
10
|
+
conditional generation with text prompts.
|
|
11
|
+
|
|
12
|
+
**Components**
|
|
13
|
+
|
|
14
|
+
- **AutoencoderLDM**: Variational autoencoder for compressing images to latent space and
|
|
15
|
+
decoding back to image space.
|
|
16
|
+
- **TrainAE**: Trainer for AutoencoderLDM, optimizing reconstruction and regularization
|
|
17
|
+
losses with evaluation metrics.
|
|
18
|
+
- **TrainLDM**: Training loop with mixed precision, warmup, and scheduling for the noise
|
|
19
|
+
predictor and conditional model (e.g., TextEncoder with projection layers) in latent
|
|
20
|
+
space, with image-domain evaluation metrics using a reverse diffusion model.
|
|
21
|
+
- **SampleLDM**: Image generation from trained models, decoding from latent to image space.
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
**Notes**
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
- The `variance_scheduler` parameter expects an external hyperparameter module (e.g.,
|
|
28
|
+
VarianceSchedulerDDPM, VarianceSchedulerSDE) as an nn.Module for noise schedule management.
|
|
29
|
+
- AutoencoderLDM serves as the `compressor_model` in TrainLDM and SampleLDM, providing
|
|
30
|
+
`encode` and `decode` methods for latent space conversion. It supports KL-divergence or
|
|
31
|
+
vector quantization (VQ) regularization, using internal components (DownBlock, UpBlock,
|
|
32
|
+
Conv3, DownSampling, UpSampling, Attention, VectorQuantizer).
|
|
33
|
+
- TrainAE trains AutoencoderLDM, optimizing reconstruction (MSE), regularization (KL or
|
|
34
|
+
VQ), and optional perceptual (LPIPS) losses, with metrics (MSE, PSNR, SSIM, FID, LPIPS)
|
|
35
|
+
computed via the Metrics class, KL warmup, early stopping, and learning rate scheduling.
|
|
36
|
+
- TrainLDM trains the noise predictor and conditional model, optimizing MSE between
|
|
37
|
+
predicted and ground truth noise, with optional validation metrics (MSE, PSNR, SSIM, FID,
|
|
38
|
+
LPIPS) on generated images decoded from latents sampled using a reverse diffusion model
|
|
39
|
+
(e.g., ReverseDDPM).
|
|
40
|
+
- SampleLDM supports multiple diffusion models ("ddpm", "ddim", "sde") via the `model`
|
|
41
|
+
parameter, requiring compatible `reverse_diffusion` modules (e.g., ReverseDDPM,
|
|
42
|
+
ReverseDDIM, ReverseSDE).
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
**References**
|
|
46
|
+
|
|
47
|
+
- Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models."
|
|
48
|
+
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
- Esser, Patrick, Robin Rombach, and Bjorn Ommer. "Taming transformers for high-resolution image synthesis."
|
|
52
|
+
Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
|
|
53
|
+
|
|
54
|
+
---------------------------------------------------------------------------------
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
import torch
|
|
59
|
+
import torch.nn as nn
|
|
60
|
+
import torch.nn.functional as F
|
|
61
|
+
from typing import Optional, Tuple, Any, Callable, List, Union, Self
|
|
62
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
63
|
+
import torch.distributed as dist
|
|
64
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
65
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
66
|
+
from torch.optim.lr_scheduler import LambdaLR
|
|
67
|
+
from transformers import BertTokenizer
|
|
68
|
+
import warnings
|
|
69
|
+
from tqdm import tqdm
|
|
70
|
+
from torchvision.utils import save_image
|
|
71
|
+
import os
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
###==================================================================================================================###
|
|
76
|
+
|
|
77
|
+
class TrainLDM(nn.Module):
|
|
78
|
+
"""Trainer for the noise predictor in Latent Diffusion Models.
|
|
79
|
+
|
|
80
|
+
Optimizes the noise predictor and conditional model (e.g., TextEncoder)
|
|
81
|
+
to predict noise in the latent space of AutoencoderLDM, using a diffusion model (e.g., DDPM, DDIM, SDE).
|
|
82
|
+
Supports mixed precision, conditional generation with text prompts, and evaluation metrics
|
|
83
|
+
(MSE, PSNR, SSIM, FID, LPIPS) for generated images during validation, using a specified reverse
|
|
84
|
+
diffusion model.
|
|
85
|
+
|
|
86
|
+
Parameters
|
|
87
|
+
----------
|
|
88
|
+
diffusion_model : str
|
|
89
|
+
Diffusion model type ("ddpm", "ddim", "sde").
|
|
90
|
+
forward_diffusion : ForwardDDPM, ForwardDDIM, or ForwardSDE
|
|
91
|
+
Forward diffusion model defining the noise schedule.
|
|
92
|
+
reverse_diffusion : ReverseDDPM, ReverseDDIM, or ReverseSDE
|
|
93
|
+
Reverse diffusion model used for sampling during validation (required; it is moved to the training device at construction).
|
|
94
|
+
noise_predictor : torch.nn.Module
|
|
95
|
+
Model to predict noise in the latent space (e.g., NoisePredictor).
|
|
96
|
+
compressor_model : torch.nn.Module
|
|
97
|
+
Variational autoencoder for encoding/decoding latents.
|
|
98
|
+
optimizer : torch.optim.Optimizer
|
|
99
|
+
Optimizer for the noise predictor and conditional model (e.g., Adam).
|
|
100
|
+
objective : Callable
|
|
101
|
+
Loss function for noise prediction (e.g., MSELoss).
|
|
102
|
+
data_loader : torch.utils.data.DataLoader
|
|
103
|
+
DataLoader for training data.
|
|
104
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
105
|
+
DataLoader for validation data (default: None).
|
|
106
|
+
conditional_model : TextEncoder, optional
|
|
107
|
+
Text encoder with projection layers for conditional generation (default: None).
|
|
108
|
+
|
|
109
|
+
metrics_ : object, optional
|
|
110
|
+
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
111
|
+
max_epochs : int, optional
|
|
112
|
+
Maximum number of training epochs (default: 1000).
|
|
113
|
+
device : str, optional
|
|
114
|
+
Device for computation (e.g., 'cuda', 'cpu') (default: None).
|
|
115
|
+
store_path : str, optional
|
|
116
|
+
Directory in which model checkpoints are saved (default: None, uses the 'ldm_model' directory).
|
|
117
|
+
patience : int, optional
|
|
118
|
+
Number of epochs to wait for early stopping if validation loss doesn’t improve
|
|
119
|
+
(default: 100).
|
|
120
|
+
warmup_epochs : int, optional
|
|
121
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
122
|
+
bert_tokenizer : BertTokenizer, optional
|
|
123
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
124
|
+
max_token_length : int, optional
|
|
125
|
+
Maximum sequence length for tokenized text (default: 77).
|
|
126
|
+
val_frequency : int, optional
|
|
127
|
+
Frequency (in epochs) for validation and metric computation (default: 10).
|
|
128
|
+
image_output_range : tuple, optional
|
|
129
|
+
Range for clamping generated images (default: (-1, 1)).
|
|
130
|
+
normalize_output : bool, optional
|
|
131
|
+
Whether to normalize generated images to [0, 1] for metrics (default: True).
|
|
132
|
+
use_ddp : bool, optional
|
|
133
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
134
|
+
grad_accumulation_steps : int, optional
|
|
135
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
136
|
+
log_frequency : int, optional
|
|
137
|
+
Number of epochs before printing loss.
|
|
138
|
+
use_compilation : bool, optional
|
|
139
|
+
Whether to compile the models with torch.compile before training (default: False).
|
|
140
|
+
"""
|
|
141
|
+
|
|
142
|
+
def __init__(
    self,
    diffusion_model: str,
    forward_diffusion: torch.nn.Module,
    reverse_diffusion: torch.nn.Module,
    noise_predictor: torch.nn.Module,
    compressor_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    data_loader: torch.utils.data.DataLoader,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    conditional_model: Optional[torch.nn.Module] = None,
    metrics_: Optional[Any] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    store_path: Optional[str] = None,
    patience: int = 100,
    warmup_epochs: int = 100,
    bert_tokenizer: Optional[BertTokenizer] = None,
    max_token_length: int = 77,
    val_frequency: int = 10,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    normalize_output: bool = True,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False
) -> None:
    """Configure device placement, models, schedulers and the tokenizer.

    Parameters are described in the class docstring.

    Raises
    ------
    ValueError
        If `diffusion_model` is not one of "ddpm", "ddim", "sde", if
        `grad_accumulation_steps` is smaller than 1, or if no tokenizer was
        supplied and the default "bert-base-uncased" tokenizer cannot be loaded.
    """
    super().__init__()
    if diffusion_model not in ["ddpm", "ddim", "sde"]:
        raise ValueError(f"Unknown model: {diffusion_model}. Supported: ddpm, ddim, sde")
    # The training loop divides the loss by this value and uses it as a modulus,
    # so zero or negative values would crash or flip the sign of the loss.
    if grad_accumulation_steps < 1:
        raise ValueError(
            f"grad_accumulation_steps must be >= 1, got {grad_accumulation_steps}"
        )
    self.diffusion_model = diffusion_model

    # initialize DDP settings first
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # setup distributed training if enabled (DDP overrides self.device with
    # the CUDA device of the local rank)
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to appropriate device
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    self.noise_predictor = noise_predictor.to(self.device)
    self.compressor_model = compressor_model.to(self.device)
    # Compare against None explicitly: container-style modules define __len__,
    # so an (empty) module could otherwise be treated as "no conditional model".
    self.conditional_model = (
        conditional_model.to(self.device) if conditional_model is not None else None
    )

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ldm_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning rate scheduling: plateau-based decay plus a linear warmup
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # initialize tokenizer
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            # chain the original error so the root cause stays visible
            raise ValueError(
                f"Failed to load default tokenizer: {e}. Please provide a tokenizer."
            ) from e
    else:
        self.tokenizer = bert_tokenizer
|
|
230
|
+
|
|
231
|
+
def _setup_ddp(self) -> None:
|
|
232
|
+
"""Setup Distributed Data Parallel training configuration.
|
|
233
|
+
|
|
234
|
+
Initializes process group, determines rank information, and sets up
|
|
235
|
+
CUDA device for the current process.
|
|
236
|
+
"""
|
|
237
|
+
# check if DDP environment variables are set
|
|
238
|
+
if "RANK" not in os.environ:
|
|
239
|
+
raise ValueError("DDP enabled but RANK environment variable not set")
|
|
240
|
+
if "LOCAL_RANK" not in os.environ:
|
|
241
|
+
raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
|
|
242
|
+
if "WORLD_SIZE" not in os.environ:
|
|
243
|
+
raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
|
|
244
|
+
|
|
245
|
+
# ensure CUDA is available for DDP
|
|
246
|
+
if not torch.cuda.is_available():
|
|
247
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
248
|
+
|
|
249
|
+
# initialize process group only if not already initialized
|
|
250
|
+
if not torch.distributed.is_initialized():
|
|
251
|
+
init_process_group(backend="nccl")
|
|
252
|
+
|
|
253
|
+
# get rank information
|
|
254
|
+
self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
|
|
255
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
|
|
256
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
|
|
257
|
+
|
|
258
|
+
# set device and make it current
|
|
259
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
260
|
+
torch.cuda.set_device(self.device)
|
|
261
|
+
|
|
262
|
+
# master process handles logging, checkpointing, etc.
|
|
263
|
+
self.master_process = self.ddp_rank == 0
|
|
264
|
+
|
|
265
|
+
if self.master_process:
|
|
266
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
267
|
+
|
|
268
|
+
def _setup_single_gpu(self) -> None:
|
|
269
|
+
"""Setup single GPU or CPU training configuration."""
|
|
270
|
+
self.ddp_rank = 0
|
|
271
|
+
self.ddp_local_rank = 0
|
|
272
|
+
self.ddp_world_size = 1
|
|
273
|
+
self.master_process = True
|
|
274
|
+
|
|
275
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Load a training checkpoint to resume training.

    Restores the noise predictor, the conditional model (when one is
    configured and present in the checkpoint), the variance scheduler of the
    forward and reverse diffusion models, and the optimizer state.

    State-dict keys are adapted to the *actual* wrapping of each model: a
    ``module.`` prefix is added when the target is wrapped in
    DistributedDataParallel and stripped (prefix only) when it is not.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch stored in the checkpoint (-1 if absent).
    loss : float
        The loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    KeyError
        If the checkpoint lacks the noise predictor, variance scheduler or
        optimizer entry.
    """
    ddp_prefix = 'module.'

    def _align_keys(state_dict, model):
        # Make the checkpoint keys match the wrapping of `model`.
        keys = list(state_dict.keys())
        all_prefixed = bool(keys) and all(key.startswith(ddp_prefix) for key in keys)
        if isinstance(model, DDP):
            if not all_prefixed:
                return {f'{ddp_prefix}{k}': v for k, v in state_dict.items()}
            return state_dict
        if all_prefixed:
            # strip only the leading prefix; inner 'module.' segments are real names
            return {k.removeprefix(ddp_prefix): v for k, v in state_dict.items()}
        return state_dict

    try:
        # load checkpoint with proper device mapping
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

    # load noise predictor state
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    self.noise_predictor.load_state_dict(
        _align_keys(checkpoint['model_state_dict_noise_predictor'], self.noise_predictor)
    )

    # load conditional model state if applicable
    if self.conditional_model is not None:
        cond_state_dict = checkpoint.get('model_state_dict_conditional')
        if cond_state_dict is not None:
            self.conditional_model.load_state_dict(
                _align_keys(cond_state_dict, self.conditional_model)
            )
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # load variance_scheduler state
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    scheduler_state = checkpoint['variance_scheduler_model']
    try:
        # handle each diffusion model on its own so that one being an
        # nn.Module never causes the other one to be overwritten
        for diffusion in (self.forward_diffusion, self.reverse_diffusion):
            if isinstance(diffusion.variance_scheduler, nn.Module):
                diffusion.variance_scheduler.load_state_dict(scheduler_state)
            else:
                diffusion.variance_scheduler = scheduler_state
    except Exception as e:
        # deliberate best effort: training can continue with the current scheduler
        warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # load optimizer state
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")

    return epoch, loss
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
@staticmethod
|
|
362
|
+
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
|
|
363
|
+
"""Creates a learning rate scheduler for warmup.
|
|
364
|
+
|
|
365
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
366
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
367
|
+
|
|
368
|
+
Parameters
|
|
369
|
+
----------
|
|
370
|
+
optimizer : torch.optim.Optimizer
|
|
371
|
+
Optimizer to apply the scheduler to.
|
|
372
|
+
warmup_epochs : int
|
|
373
|
+
Number of epochs for the warmup phase.
|
|
374
|
+
|
|
375
|
+
Returns
|
|
376
|
+
-------
|
|
377
|
+
torch.optim.lr_scheduler.LambdaLR
|
|
378
|
+
Learning rate scheduler for warmup.
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
def lr_lambda(epoch):
|
|
382
|
+
if epoch < warmup_epochs:
|
|
383
|
+
return epoch / warmup_epochs
|
|
384
|
+
return 1.0
|
|
385
|
+
|
|
386
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
387
|
+
|
|
388
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
389
|
+
"""Wrap models with DistributedDataParallel for multi-GPU training."""
|
|
390
|
+
if self.use_ddp:
|
|
391
|
+
# wrap noise predictor with DDP
|
|
392
|
+
self.noise_predictor = DDP(
|
|
393
|
+
self.noise_predictor,
|
|
394
|
+
device_ids=[self.ddp_local_rank],
|
|
395
|
+
find_unused_parameters=True
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
# wrap conditional model with DDP if it exists
|
|
399
|
+
if self.conditional_model is not None:
|
|
400
|
+
self.conditional_model = DDP(
|
|
401
|
+
self.conditional_model,
|
|
402
|
+
device_ids=[self.ddp_local_rank],
|
|
403
|
+
find_unused_parameters=True
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
def forward(self) -> Tuple[List, float]:
    """Train the noise predictor (and conditional model) in latent space.

    Each batch is encoded with the frozen compressor model, noised by the
    forward diffusion model at random timesteps, and the noise predictor is
    optimized to recover the noise under autocast with gradient scaling and
    gradient accumulation. Every `val_frequency` epochs (when a validation
    loader is set) `validate()` is run and its loss drives the plateau
    scheduler, checkpointing and early stopping; otherwise the mean training
    loss is used.

    Under DDP the early-stopping decision taken by the master process is
    broadcast so that all ranks leave the training loop together.

    Returns
    -------
    train_losses : List of float
        Mean training loss per epoch.
    best_val_loss : float
        Best monitored loss (validation loss, or training loss if no
        validation is configured) as tracked by the master process.
    """
    # set models to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    self.compressor_model.eval()  # pre-trained compressor model
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (if supported)
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
            self.compressor_model = torch.compile(self.compressor_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # self.device is a torch.device, so compare its `.type`; comparing the
    # device object itself with the string 'cuda' is always False and would
    # silently run CUDA training with CPU autocast.
    autocast_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # the compressor is frozen: encode to latents without gradients
            with torch.no_grad():
                x, _ = self.compressor_model.encode(x)

            # process conditional inputs if conditional model exists
            if self.conditional_model is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                # generate noise and timesteps
                noise = torch.randn_like(x).to(self.device)
                t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

                # apply forward diffusion
                noisy_x = self.forward_diffusion(x, noise, t)

                # predict noise
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)

                # compute loss and scale for gradient accumulation
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # backward pass
            scaler.scale(loss).backward()

            # gradient accumulation and optimizer step
            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients (unscale first so the norm is measured on real gradients)
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
                if self.conditional_model is not None:
                    torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

                # optimizer step
                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # update learning rate (warmup scheduler)
                # NOTE: stepped once per optimizer update, not once per epoch
                self.warmup_lr_scheduler.step()

            # record loss (unscaled)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # compute mean training loss
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

        # all-reduce loss across processes for DDP
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        train_losses.append(mean_train_loss)

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        # validation step
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # save checkpoint and early stopping (decided by the master process)
        stop_training = False
        if self.master_process:
            if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop_training = True

        # Share the decision with every rank; if only rank 0 left the loop the
        # remaining ranks would block forever in the next collective call.
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())
        if stop_training:
            break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
569
|
+
|
|
570
|
+
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
|
|
571
|
+
"""Process conditional input for text-to-image generation.
|
|
572
|
+
|
|
573
|
+
Parameters
|
|
574
|
+
----------
|
|
575
|
+
y : torch.Tensor or list
|
|
576
|
+
Conditional input (text prompts).
|
|
577
|
+
|
|
578
|
+
Returns
|
|
579
|
+
-------
|
|
580
|
+
torch.Tensor
|
|
581
|
+
Encoded conditional input.
|
|
582
|
+
"""
|
|
583
|
+
# convert to string list
|
|
584
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
585
|
+
y_list = [str(item) for item in y_list]
|
|
586
|
+
|
|
587
|
+
# tokenize
|
|
588
|
+
y_encoded = self.tokenizer(
|
|
589
|
+
y_list,
|
|
590
|
+
padding="max_length",
|
|
591
|
+
truncation=True,
|
|
592
|
+
max_length=self.max_token_length,
|
|
593
|
+
return_tensors="pt"
|
|
594
|
+
).to(self.device)
|
|
595
|
+
|
|
596
|
+
# get embeddings
|
|
597
|
+
input_ids = y_encoded["input_ids"]
|
|
598
|
+
attention_mask = y_encoded["attention_mask"]
|
|
599
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
600
|
+
|
|
601
|
+
return y_encoded
|
|
602
|
+
|
|
603
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
    """Save model checkpoint (only called by master process).

    Writes `ldm_epoch_{epoch}{suffix}.pth` into `self.store_path` (created
    if missing). The file holds the epoch, the noise-predictor and
    conditional-model state dicts, the optimizer state dict, the loss, the
    variance scheduler (its state dict when it is an `nn.Module`, otherwise
    the object itself) and `self.max_epochs`.

    Saving is best-effort: any exception is reported with `print` and not
    re-raised, so a failed save does not interrupt the caller.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    loss : float
        Current loss value.
    suffix : str, optional
        Suffix to add to checkpoint filename.
    """

    def _unwrapped(model):
        # With DDP enabled the real network is expected under `.module`.
        # A model that was left unwrapped has no such attribute; previously
        # that raised AttributeError, which the handler below swallowed, so
        # no checkpoint was written at all.
        if self.use_ddp and hasattr(model, "module"):
            return model.module
        return model

    try:
        noise_predictor_state = _unwrapped(self.noise_predictor).state_dict()

        conditional_state = None
        if self.conditional_model is not None:
            conditional_state = _unwrapped(self.conditional_model).state_dict()

        variance_scheduler = self.forward_diffusion.variance_scheduler
        if isinstance(variance_scheduler, nn.Module):
            variance_scheduler = variance_scheduler.state_dict()

        checkpoint = {
            'epoch': epoch,
            'model_state_dict_noise_predictor': noise_predictor_state,
            'model_state_dict_conditional': conditional_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'variance_scheduler_model': variance_scheduler,
            'max_epochs': self.max_epochs,
        }

        filename = f"ldm_epoch_{epoch}{suffix}.pth"
        filepath = os.path.join(self.store_path, filename)
        os.makedirs(self.store_path, exist_ok=True)
        torch.save(checkpoint, filepath)

        print(f"Model saved at epoch {epoch}")

    except Exception as e:
        # deliberate best-effort: report and continue
        print(f"Failed to save model: {e}")
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Run one pass over the validation loader and report loss and metrics.

    For every batch the images are encoded with `self.compressor_model`,
    noised at random timesteps and the objective between predicted and true
    noise is recorded. When both `self.metrics_` and
    `self.reverse_diffusion` are set, a full reverse-diffusion sample is
    additionally drawn in latent space, decoded, clamped to
    `self.image_output_range` (and rescaled when `self.normalize_output` is
    set) and compared with the input images through `self.metrics_`.

    Returns
    -------
    tuple
        (val_loss, fid, mse, psnr, ssim, lpips_score). `fid` is
        `float('inf')` and the remaining metrics are None when no values
        were collected for them. Only `val_loss` is averaged across
        processes when `self.use_ddp` is set.

    Raises
    ------
    ValueError
        If sampling is performed and `self.diffusion_model` is not one of
        "ddpm", "ddim" or "sde".
    """
    scheduler = self.forward_diffusion.variance_scheduler
    text_conditioned = self.conditional_model is not None

    # inference mode for everything that takes part in validation
    self.noise_predictor.eval()
    if text_conditioned:
        self.conditional_model.eval()
    if scheduler.trainable_beta:
        self.forward_diffusion.eval()
        self.reverse_diffusion.eval()

    # DDIM samples over its shortened schedule, the others over the full one
    if self.diffusion_model == "ddim":
        sampling_steps = scheduler.tau_num_steps
    else:
        sampling_steps = scheduler.num_steps

    batch_losses = []
    collected = {"fid": [], "mse": [], "psnr": [], "ssim": [], "lpips": []}

    def _enabled(flag_name):
        # a metric family counts only if the metrics object exposes a truthy flag
        return bool(hasattr(self.metrics_, flag_name) and getattr(self.metrics_, flag_name))

    def _average(values, when_empty):
        return torch.tensor(values).mean().item() if values else when_empty

    with torch.no_grad():
        for images, labels in self.val_loader:
            images = images.to(self.device)
            reference = images.clone()
            latents, _ = self.compressor_model.encode(images)

            cond = self._process_conditional_input(labels) if text_conditioned else None

            # noise-prediction loss at random timesteps
            target_noise = torch.randn_like(latents).to(self.device)
            random_t = torch.randint(0, scheduler.num_steps, (latents.shape[0],)).to(self.device)
            noised = self.forward_diffusion(latents, target_noise, random_t)
            eps_hat = self.noise_predictor(noised, random_t, cond, None)
            batch_losses.append(self.objective(eps_hat, target_noise).item())

            if self.metrics_ is None or self.reverse_diffusion is None:
                continue

            # draw a sample by running the reverse process from pure noise
            xt = torch.randn_like(latents).to(self.device)
            for step in reversed(range(sampling_steps)):
                current = torch.full((xt.shape[0],), step, device=self.device)
                previous = torch.full((xt.shape[0],), max(step - 1, 0), device=self.device)
                eps_hat = self.noise_predictor(xt, current, cond, None)

                if self.diffusion_model == "sde":
                    if getattr(self.reverse_diffusion, "sde_method", None) != "ode":
                        fresh_noise = torch.randn_like(xt)
                    else:
                        fresh_noise = None
                    xt = self.reverse_diffusion(xt, fresh_noise, eps_hat, current)
                elif self.diffusion_model == "ddim":
                    xt, _ = self.reverse_diffusion(xt, eps_hat, current, previous)
                elif self.diffusion_model == "ddpm":
                    xt = self.reverse_diffusion(xt, eps_hat, current)
                else:
                    raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")

            generated = self.compressor_model.decode(xt)

            # clamp, then optionally rescale both images to the unit range
            lower = self.image_output_range[0]
            upper = self.image_output_range[1]
            generated = torch.clamp(generated, min=lower, max=upper)
            if self.normalize_output:
                generated = (generated - lower) / (upper - lower)
                reference = (reference - lower) / (upper - lower)

            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(reference, generated)

            if _enabled('fid'):
                collected["fid"].append(fid)
            if _enabled('metrics'):
                collected["mse"].append(mse)
                collected["psnr"].append(psnr)
                collected["ssim"].append(ssim)
            if _enabled('lpips'):
                collected["lpips"].append(lpips_score)

    val_loss = torch.tensor(batch_losses).mean().item()

    # average the loss over all ranks when running distributed
    if self.use_ddp:
        synced = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(synced, op=dist.ReduceOp.AVG)
        val_loss = synced.item()

    fid_avg = _average(collected["fid"], float('inf'))
    mse_avg = _average(collected["mse"], None)
    psnr_avg = _average(collected["psnr"], None)
    ssim_avg = _average(collected["ssim"], None)
    lpips_avg = _average(collected["lpips"], None)

    # back to training mode
    self.noise_predictor.train()
    if text_conditioned:
        self.conditional_model.train()
    if scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
###==================================================================================================================###
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
class SampleLDM(nn.Module):
    """Sampler for generating images using Latent Diffusion Models (LDM).

    Generates images by iteratively denoising random noise in the latent space using a
    reverse diffusion process, decoding the result back to the image space with a
    pre-trained compressor, as described in Rombach et al. (2022). Supports DDPM, DDIM,
    and SDE diffusion models, as well as conditional generation with text prompts.

    Parameters
    ----------
    diffusion_model : str
        Diffusion model type. Supported: "ddpm", "ddim", "sde".
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
    noise_predictor : nn.Module
        Model to predict noise added during the forward diffusion process.
    compressor_model : nn.Module
        Pre-trained model to encode/decode between image and latent spaces (e.g., AutoencoderLDM).
    image_shape : tuple
        Shape of generated images as (height, width); both must be positive ints.
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., TextEncoder), default None.
    bert_tokenizer : str or BertTokenizer, optional
        Tokenizer for processing text prompts, default "bert-base-uncased".
        A string is loaded with `BertTokenizer.from_pretrained`; any other
        object is used directly as the tokenizer.
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of channels of the initial noise tensor (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    image_output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Raises
    ------
    ValueError
        If `image_shape`, `batch_size`, `in_channels` or
        `image_output_range` is invalid.
    """
    def __init__(
            self,
            diffusion_model: str,
            reverse_diffusion: torch.nn.Module,
            noise_predictor: torch.nn.Module,
            compressor_model: torch.nn.Module,
            image_shape: Tuple[int, int],
            conditional_model: Optional[torch.nn.Module] = None,
            bert_tokenizer: Union[str, "BertTokenizer"] = "bert-base-uncased",
            batch_size: int = 1,
            in_channels: int = 3,
            device: Optional[Union[str, torch.device]] = None,
            max_token_length: int = 77,
            image_output_range: Tuple[float, float] = (-1.0, 1.0)
    ) -> None:
        super().__init__()

        # Validate the cheap arguments first so a bad call fails before any
        # model is moved to a device or a tokenizer is loaded.
        if (not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2
                or not all(isinstance(s, int) and s > 0 for s in image_shape)):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if (not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2
                or image_output_range[0] >= image_output_range[1]):
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        self.diffusion_model = diffusion_model
        self.noise_predictor = noise_predictor.to(self.device)
        self.reverse = reverse_diffusion.to(self.device)
        self.compressor = compressor_model.to(self.device)
        # `is not None` rather than truthiness: container modules define
        # __len__, so an empty one would otherwise be treated as "no model".
        self.conditional_model = (
            conditional_model.to(self.device) if conditional_model is not None else None
        )
        # a name is resolved to a pretrained tokenizer; an object is used as is
        if isinstance(bert_tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer)
        else:
            self.tokenizer = bert_tokenizer
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.max_token_length = max_token_length
        self.image_output_range = image_output_range

    def tokenize(self, prompts: Union[List, str]):
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized tensors using the specified tokenizer,
        padded/truncated to `max_token_length`.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list of strings.

        Returns
        -------
        input_ids : torch.Tensor
            Tokenized input IDs on `self.device`.
        attention_mask : torch.Tensor
            Attention mask on `self.device`.

        Raises
        ------
        TypeError
            If `prompts` is neither a string nor a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")

        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(
            self,
            conditions: Optional[Union[List, str]] = None,
            normalize_output: bool = True,
            save_images: bool = True,
            save_path: str = "ldm_generated"
    ) -> torch.Tensor:
        """Generates images using the reverse diffusion process in the latent space.

        Random noise of shape (batch_size, in_channels, height, width) is passed
        through the compressor's encoder, iteratively denoised with the configured
        reverse diffusion model (DDPM, DDIM, SDE), then decoded back to image space
        with the compressor model. Supports conditional generation with text prompts.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).
        save_images : bool, optional
            If True, saves generated images to `save_path` (default: True).
        save_path : str, optional
            Directory to save generated images (default: "ldm_generated").

        Returns
        -------
        generated_imgs (torch.Tensor) - Decoded images. If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `image_output_range`.

        Raises
        ------
        ValueError
            If conditions and conditional model do not match (one given
            without the other), or `diffusion_model` is not supported.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(
            self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]
        ).to(self.device)

        self.noise_predictor.eval()
        self.compressor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            xt, _ = self.compressor.encode(noisy_samples)

            if self.diffusion_model == "ddim":
                num_steps = self.reverse.variance_scheduler.tau_num_steps
            elif self.diffusion_model == "ddpm" or self.diffusion_model == "sde":
                num_steps = self.reverse.variance_scheduler.num_steps
            else:
                raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")

            # The prompt embedding does not depend on the timestep, so it is
            # computed once here instead of on every denoising step.
            y = None
            if self.conditional_model is not None and conditions is not None:
                input_ids, attention_masks = self.tokenize(conditions)
                key_padding_mask = (attention_masks == 0)  # True where the token is padding
                y = self.conditional_model(input_ids, key_padding_mask)

            for t in reversed(range(num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device)
                prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)

                # the SDE sampler takes fresh noise unless it runs in "ode" mode
                if self.diffusion_model == "sde":
                    noise = torch.randn_like(xt) if getattr(self.reverse, "sde_method", None) != "ode" else None

                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                if self.diffusion_model == "sde":
                    xt = self.reverse(xt, noise, predicted_noise, time_steps)
                elif self.diffusion_model == "ddim":
                    xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
                elif self.diffusion_model == "ddpm":
                    xt = self.reverse(xt, predicted_noise, time_steps)
                else:
                    raise ValueError(f"Unknown model: {self.diffusion_model}. Supported: ddpm, ddim, sde")

            x = self.compressor.decode(xt)
            generated_imgs = torch.clamp(x, min=self.image_output_range[0], max=self.image_output_range[1])
            if normalize_output:
                generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])

            # save images if save_images is True
            if save_images:
                os.makedirs(save_path, exist_ok=True)
                for i in range(generated_imgs.size(0)):
                    img_path = os.path.join(save_path, f"image_{i+1}.png")
                    save_image(generated_imgs[i], img_path)

        return generated_imgs

    def to(self, device: torch.device) -> Self:
        """Moves the module and its components to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for computation (a device string is also accepted).

        Returns
        -------
        sample (SampleLDM) - The module moved to the specified device.
        """
        # keep `self.device` a torch.device, as the constructor does
        self.device = torch.device(device) if isinstance(device, str) else device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        self.compressor.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)
|
|
984
|
+
|
|
985
|
+
###==================================================================================================================###
|
|
986
|
+
|
|
987
|
+
class AutoencoderLDM(nn.Module):
    """Variational autoencoder for latent space compression in Latent Diffusion Models.

    Encodes images into a latent space and decodes them back to the image space, used as
    the `compressor_model` in LDM's `TrainLDM` and `SampleLDM`. Supports KL-divergence
    or vector quantization (VQ) regularization for the latent representation.

    Parameters
    ----------
    in_channels : int
        Number of input channels (e.g., 3 for RGB images).
    down_channels : list
        List of channel sizes for encoder downsampling blocks (e.g., [32, 64, 128, 256]).
    up_channels : list
        List of channel sizes for decoder upsampling blocks (e.g., [256, 128, 64, 16]).
    out_channels : int
        Number of output channels; must equal `in_channels`.
    dropout_rate : float
        Dropout rate for regularization in convolutional and attention layers.
    num_heads : int
        Number of attention heads in self-attention layers.
    num_groups : int
        Number of groups for group normalization in attention layers.
    num_layers_per_block : int
        Number of convolutional layers in each downsampling and upsampling block.
    total_down_sampling_factor : int
        Total downsampling factor across the encoder (e.g., 8 for 8x reduction).
        The per-block factor is its `len(down_channels) - 1`-th root.
    latent_channels : int
        Number of channels in the latent representation for diffusion models.
    num_embeddings : int
        Number of discrete embeddings in the VQ codebook (if `use_vq=True`).
    use_vq : bool, optional
        If True, uses vector quantization (VQ) regularization; otherwise, uses
        KL-divergence (default: False).
    beta : float, optional
        Weight for KL-divergence loss (if `use_vq=False`) (default: 1.0).
    """
    def __init__(
            self,
            in_channels: int,
            down_channels: List[int],
            up_channels: List[int],
            out_channels: int,
            dropout_rate: float,
            num_heads: int,
            num_groups: int,
            num_layers_per_block: int,
            total_down_sampling_factor: int,
            latent_channels: int,
            num_embeddings: int,
            use_vq: bool = False,
            beta: float = 1.0
    ) -> None:
        super().__init__()
        assert in_channels == out_channels, "Input and output channels must match for auto-encoding"
        self.use_vq = use_vq
        self.beta = beta
        self.current_beta = beta
        num_down_blocks = len(down_channels) - 1
        self.down_sampling_factor = self._per_block_factor(total_down_sampling_factor, num_down_blocks)

        # encoder
        self.conv1 = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=down_channels[i],
                out_channels=down_channels[i + 1],
                num_layers=num_layers_per_block,
                down_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(num_down_blocks)
        ])
        self.attention1 = Attention(down_channels[-1], num_heads, num_groups, dropout_rate)

        # latent projection
        if use_vq:
            self.vq_layer = VectorQuantizer(num_embeddings, down_channels[-1])
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)
        else:
            self.conv_mu = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.conv_logvar = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)

        # decoder
        self.conv2 = nn.Conv2d(latent_channels, up_channels[0], kernel_size=3, padding=1)
        self.attention2 = Attention(up_channels[0], num_heads, num_groups, dropout_rate)
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=up_channels[i],
                out_channels=up_channels[i + 1],
                num_layers=num_layers_per_block,
                up_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(len(up_channels) - 1)
        ])
        self.conv3 = Conv3(up_channels[-1], out_channels, dropout_rate)

    @staticmethod
    def _per_block_factor(total_factor: int, num_blocks: int) -> int:
        """Return the per-block scale so that `num_blocks` blocks give `total_factor`.

        A floating-point root can land just below the true integer
        (e.g. 64 ** (1/3) == 3.9999999999999996), which plain truncation
        turns into 3. The nearest integer is used when it is an exact root;
        otherwise the truncated root is returned, as before.
        """
        approx_root = total_factor ** (1 / num_blocks)
        nearest = round(approx_root)
        if nearest ** num_blocks == total_factor:
            return int(nearest)
        return int(approx_root)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Applies reparameterization trick for variational autoencoding.

        Samples from a Gaussian distribution using the mean and log-variance to enable
        differentiable training.

        Parameters
        ----------
        mu : torch.Tensor
            Mean of the latent distribution.
        logvar : torch.Tensor
            Log-variance of the latent distribution, same shape as `mu`.

        Returns
        -------
        reparam (torch.Tensor) - Sampled latent representation, same shape as `mu`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes images into a latent representation.

        Processes input images through the encoder, applying convolutions, downsampling,
        self-attention (with a residual connection), and latent projection (VQ or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        z : torch.Tensor
            Latent representation with `latent_channels` channels; the spatial
            size is whatever the down blocks produce.
        reg_loss : torch.Tensor
            Regularization loss (VQ loss if `use_vq=True`, KL-divergence loss if `use_vq=False`).

        **Notes**

        - The VQ loss is computed by `VectorQuantizer` if `use_vq=True`.
        - The KL-divergence loss is normalized by batch size and latent size, weighted
          by `current_beta`.
        """
        x = self.conv1(x)
        for block in self.down_blocks:
            x = block(x)
        res_x = x
        x = self.attention1(x)
        x = x + res_x
        if self.use_vq:
            z, vq_loss = self.vq_layer(x)
            z = self.quant_conv(z)
            return z, vq_loss
        else:
            mu = self.conv_mu(x)
            logvar = self.conv_logvar(x)
            z = self.reparameterize(mu, logvar)
            z = self.quant_conv(z)
            kl_unnormalized = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            batch_size = x.size(0)
            latent_size = torch.prod(torch.tensor(mu.shape[1:])).item()
            # mean KL per latent element, scaled by the current KL weight
            kl_loss = kl_unnormalized / (batch_size * latent_size) * self.current_beta
            return z, kl_loss

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decodes latent representations back to images.

        Processes latent representations through the decoder, applying convolutions,
        self-attention (with a residual connection), upsampling, and final reconstruction.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation with `latent_channels` channels.

        Returns
        -------
        x (torch.Tensor) - Reconstructed images with `out_channels` channels.
        """
        x = self.conv2(z)
        res_x = x
        x = self.attention2(x)
        x = x + res_x
        for block in self.up_blocks:
            x = block(x)
        x = self.conv3(x)
        return x

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes images to latent space and decodes them, computing reconstruction and regularization losses.

        Performs a full autoencoding pass, encoding images to the latent space, decoding
        them back, and calculating MSE reconstruction loss and regularization loss (VQ
        or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x_hat : torch.Tensor
            Reconstructed images.
        total_loss : torch.Tensor
            Sum of reconstruction (MSE) and regularization losses. Both terms
            stay attached to the autograd graph.
        reg_loss : torch.Tensor
            Regularization loss (VQ or KL-divergence).
        z : torch.Tensor
            Latent representation returned by `encode`.

        **Notes**

        - The reconstruction loss is computed as the mean squared error between `x_hat` and `x`.
        - The regularization loss depends on `use_vq` (VQ loss or KL-divergence).
        """
        z, reg_loss = self.encode(x)
        x_hat = self.decode(z)
        recon_loss = F.mse_loss(x_hat, x)
        # Keep the reconstruction term as a tensor: converting it with
        # `.item()` would cut it out of the graph, so back-propagating
        # `total_loss` would never update the decoder.
        total_loss = recon_loss + reg_loss
        return x_hat, total_loss, reg_loss, z
|
|
1208
|
+
|
|
1209
|
+
###==================================================================================================================###
|
|
1210
|
+
|
|
1211
|
+
class VectorQuantizer(nn.Module):
    """Vector quantization layer for discretizing latent representations.

    Replaces every latent vector with its nearest entry (Euclidean distance) in a
    learned codebook and returns the auxiliary loss that trains the codebook and
    keeps the encoder output close to it.

    Parameters
    ----------
    num_embeddings : int
        Number of discrete embeddings in the codebook.
    embedding_dim : int
        Dimensionality of each embedding vector (must equal the size of dim 1 of the input).
    commitment_cost : float, optional
        Weight of the commitment (encoder-side) loss term (default: 0.25).

    **Notes**

    - The codebook is initialized uniformly in [-1/num_embeddings, 1/num_embeddings].
    - The channel axis (dim 1) is moved to the last position before flattening, so each
      flattened row is the full feature vector of exactly one position.
    - The codebook loss pulls embeddings towards the (detached) inputs; the commitment
      loss, scaled by `commitment_cost`, pulls inputs towards the (detached) embeddings.
    - A straight-through estimator copies gradients from the quantized output to the input.
    """
    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantizes latent representations to the nearest codebook embedding.

        Parameters
        ----------
        z : torch.Tensor
            Input latent representation with the feature axis at dim 1, e.g.
            (batch_size, embedding_dim, height, width).

        Returns
        -------
        quantized : torch.Tensor
            Quantized latent representation, same shape as `z`. Its value equals the
            selected codebook vectors; its gradient flows straight through to `z`.
        vq_loss : torch.Tensor
            Scalar sum of the codebook loss and the weighted commitment loss.

        Raises
        ------
        AssertionError
            If dim 1 of `z` does not equal `embedding_dim`.
        """
        assert z.size(1) == self.embedding_dim, f"Expected channel dim {self.embedding_dim}, got {z.size(1)}"

        # Put channels last BEFORE flattening; reshaping a channels-first tensor
        # directly would build rows out of values from different positions.
        z_channels_last = z.movedim(1, -1).contiguous()
        z_flattened = z_channels_last.reshape(-1, self.embedding_dim)

        codebook = self.embedding.weight
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z.e, for every (row, code) pair.
        distances = (torch.sum(z_flattened ** 2, dim=1, keepdim=True)
                     + torch.sum(codebook ** 2, dim=1)
                     - 2 * torch.matmul(z_flattened, codebook.t()))
        encoding_indices = torch.argmin(distances, dim=1)

        # Embedding lookup is equivalent to one-hot @ codebook without the dense matrix.
        quantized = self.embedding(encoding_indices).view_as(z_channels_last)
        quantized = quantized.movedim(-1, 1).contiguous()

        # Codebook term: gradient reaches the embeddings only.
        codebook_loss = torch.mean((quantized - z.detach()) ** 2)
        # Commitment term: gradient reaches the encoder output only, weighted by commitment_cost.
        commitment_loss = self.commitment_cost * torch.mean((z - quantized.detach()) ** 2)

        # Straight-through estimator: forward value is `quantized`, backward is identity w.r.t. z.
        quantized = z + (quantized - z).detach()
        return quantized, commitment_loss + codebook_loss
|
|
1282
|
+
|
|
1283
|
+
###==================================================================================================================###
|
|
1284
|
+
|
|
1285
|
+
class DownBlock(nn.Module):
    """Encoder stage: residual Conv3 pairs followed by spatial downsampling.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming tensor.
    out_channels : int
        Number of channels produced by every Conv3 pair and by the downsampler.
    num_layers : int
        Number of residual Conv3 pairs.
    down_sampling_factor : int
        Factor handed to `DownSampling`.
    dropout_rate : float
        Dropout rate handed to every Conv3 layer.

    **Notes**

    - Only the first pair (and its 1x1 shortcut) receives `in_channels`; later pairs
      operate on `out_channels`.
    - Downsampling runs once, after all residual pairs.
    """
    def __init__(self, in_channels: int, out_channels: int, num_layers: int, down_sampling_factor: int, dropout_rate: float) -> None:
        super().__init__()
        self.num_layers = num_layers

        # Input width of each residual pair.
        stage_inputs = [in_channels if idx == 0 else out_channels for idx in range(self.num_layers)]

        self.conv1 = nn.ModuleList()
        for width in stage_inputs:
            self.conv1.append(Conv3(in_channels=width, out_channels=out_channels, dropout_rate=dropout_rate))

        self.conv2 = nn.ModuleList()
        for _ in stage_inputs:
            self.conv2.append(Conv3(in_channels=out_channels, out_channels=out_channels, dropout_rate=dropout_rate))

        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor
        )

        # 1x1 shortcuts that match the channel count of each pair's output.
        self.resnet = nn.ModuleList()
        for width in stage_inputs:
            self.resnet.append(nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs the residual Conv3 pairs, then downsamples.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        output (torch.Tensor) - Result of `DownSampling` applied to the last residual output.
        """
        output = x
        for idx in range(self.num_layers):
            main_path = self.conv2[idx](self.conv1[idx](output))
            output = main_path + self.resnet[idx](output)
        return self.down_sampling(output)
|
|
1362
|
+
|
|
1363
|
+
###==================================================================================================================###
|
|
1364
|
+
|
|
1365
|
+
class Conv3(nn.Module):
    """GroupNorm -> SiLU -> Dropout -> 3x3 convolution.

    Building block used by DownBlock and UpBlock.

    Parameters
    ----------
    in_channels : int
        Number of input channels (normalized with 8 groups).
    out_channels : int
        Number of output channels of the convolution.
    dropout_rate : float
        Dropout probability applied before the convolution.

    **Notes**

    - The convolution uses kernel_size=3 with padding=1, so height and width are preserved.
    """
    def __init__(self, in_channels: int, out_channels: int, dropout_rate: float) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(8, in_channels)
        self.activation = nn.SiLU()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies normalization, activation, dropout and convolution in that order.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        for stage in (self.group_norm, self.activation, self.dropout, self.conv):
            x = stage(x)
        return x
|
|
1409
|
+
|
|
1410
|
+
###==================================================================================================================###
|
|
1411
|
+
|
|
1412
|
+
class DownSampling(nn.Module):
    """Downsampling module that concatenates a strided-conv path and a max-pool path.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Total number of output channels (conv path + pool path).
    down_sampling_factor : int
        Factor by which to downsample spatial dimensions.

    **Notes**

    - The conv path produces `out_channels // 2` channels and the pool path produces the
      remainder, so the concatenated output always has exactly `out_channels` channels,
      including when `out_channels` is odd.
    - The conv path (kernel 3, padding 1, stride = factor) yields ceil(size / factor)
      positions; the pooling path uses `ceil_mode=True` so it yields the same size and the
      two outputs can be concatenated even when the input size is not a multiple of the factor.
    """
    def __init__(self, in_channels: int, out_channels: int, down_sampling_factor: int) -> None:
        super().__init__()
        self.down_sampling_factor = down_sampling_factor

        conv_channels = out_channels // 2
        # The pool path absorbs the extra channel when out_channels is odd.
        pool_channels = out_channels - conv_channels

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=conv_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        )
        self.pool = nn.Sequential(
            # ceil_mode keeps the pooled size equal to the strided-conv size for any input size.
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor, ceil_mode=True),
            nn.Conv2d(in_channels=in_channels, out_channels=pool_channels,
                      kernel_size=1, stride=1, padding=0)
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Downsamples input by combining convolutional and pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels,
        ceil(height / down_sampling_factor), ceil(width / down_sampling_factor)).
        """
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
|
|
1459
|
+
|
|
1460
|
+
###==================================================================================================================###
|
|
1461
|
+
|
|
1462
|
+
class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Parameters
    ----------
    num_channels : int
        Number of input and output channels (embedding dimension for attention).
    num_heads : int
        Number of attention heads.
    num_groups : int
        Number of groups for group normalization.
    dropout_rate : float
        Dropout rate applied to the attention output.

    **Notes**

    - The feature map is flattened to (batch_size, num_channels, height * width),
      group-normalized, and transposed so every spatial position becomes one token.
    - The attention output is reshaped back to (batch_size, num_channels, height, width).
    """
    def __init__(self, num_channels: int, num_heads: int, num_groups: int, dropout_rate: float) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(num_groups, num_channels)
        self.attention = nn.MultiheadAttention(num_channels, num_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies self-attention to input features.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, num_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Output tensor, same shape as input.
        """
        batch_size, channels, height, width = x.shape

        # (B, C, H*W) -> normalize -> (B, H*W, C): one token per spatial position.
        flat = x.reshape(batch_size, channels, height * width)
        tokens = self.group_norm(flat).transpose(1, 2)

        attended, _ = self.attention(tokens, tokens, tokens)
        attended = self.dropout(attended)

        # Back to channels-first image layout.
        return attended.transpose(1, 2).reshape(batch_size, channels, height, width)
|
|
1510
|
+
|
|
1511
|
+
###==================================================================================================================###
|
|
1512
|
+
|
|
1513
|
+
class UpBlock(nn.Module):
    """Decoder stage: spatial upsampling followed by residual Conv3 pairs.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming tensor (kept by the upsampler).
    out_channels : int
        Number of channels produced by every Conv3 pair.
    num_layers : int
        Number of residual Conv3 pairs.
    up_sampling_factor : int
        Factor handed to `UpSampling`.
    dropout_rate : float
        Dropout rate handed to every Conv3 layer.

    **Notes**

    - Upsampling runs first and is configured with `out_channels=in_channels`.
    - Only the first pair (and its 1x1 shortcut) receives `in_channels`; later pairs
      operate on `out_channels`.
    """
    def __init__(self, in_channels: int, out_channels: int, num_layers: int, up_sampling_factor: int, dropout_rate: float) -> None:
        super().__init__()
        self.num_layers = num_layers

        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor
        )

        # Input width of each residual pair.
        stage_inputs = [in_channels if idx == 0 else out_channels for idx in range(self.num_layers)]

        self.conv1 = nn.ModuleList()
        for width in stage_inputs:
            self.conv1.append(Conv3(in_channels=width, out_channels=out_channels, dropout_rate=dropout_rate))

        self.conv2 = nn.ModuleList()
        for _ in stage_inputs:
            self.conv2.append(Conv3(in_channels=out_channels, out_channels=out_channels, dropout_rate=dropout_rate))

        # 1x1 shortcuts that match the channel count of each pair's output.
        self.resnet = nn.ModuleList()
        for width in stage_inputs:
            self.resnet.append(nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsamples, then runs the residual Conv3 pairs.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        output (torch.Tensor) - Output of the last residual pair, with `out_channels` channels
        at the spatial size produced by `UpSampling`.
        """
        output = self.up_sampling(x)
        for idx in range(self.num_layers):
            main_path = self.conv2[idx](self.conv1[idx](output))
            output = main_path + self.resnet[idx](output)
        return output
|
|
1591
|
+
|
|
1592
|
+
###==================================================================================================================###
|
|
1593
|
+
|
|
1594
|
+
class UpSampling(nn.Module):
    """Upsampling module that concatenates a transposed-conv path and a nearest-neighbor path.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Total number of output channels (transposed-conv path + upsample path).
    up_sampling_factor : int
        Factor by which to upsample spatial dimensions.

    **Notes**

    - The transposed-conv path produces `out_channels // 2` channels and the upsample path
      produces the remainder, so the concatenated output always has exactly `out_channels`
      channels, including when `out_channels` is odd.
    - If the spatial sizes of the two paths ever differ, the upsample path is resized
      (nearest-neighbor) to the transposed-conv path's size before concatenation.
    """
    def __init__(self, in_channels: int, out_channels: int, up_sampling_factor: int) -> None:
        super().__init__()
        conv_channels = out_channels // 2
        # The upsample path absorbs the extra channel when out_channels is odd.
        upsample_channels = out_channels - conv_channels
        self.up_sampling_factor = up_sampling_factor
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                # With kernel 3 / padding 1 this makes the output exactly size * factor.
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=conv_channels,
                out_channels=conv_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )
        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=upsample_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Upsamples input by combining transposed convolution and upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        x (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels,
        height * up_sampling_factor, width * up_sampling_factor).
        """
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        # Defensive: align the interpolation path to the transposed-conv path if sizes differ.
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
|
|
1672
|
+
|
|
1673
|
+
###==================================================================================================================###
|
|
1674
|
+
|
|
1675
|
+
class TrainAE(nn.Module):
|
|
1676
|
+
"""Trainer for the AutoencoderLDM variational autoencoder in Latent Diffusion Models.
|
|
1677
|
+
|
|
1678
|
+
Optimizes the AutoencoderLDM model to compress images into latent space and reconstruct
|
|
1679
|
+
them, using reconstruction loss (MSE), regularization (KL or VQ), and optional
|
|
1680
|
+
perceptual loss (LPIPS). Supports mixed precision, KL warmup, early stopping, and
|
|
1681
|
+
learning rate scheduling, with evaluation metrics (MSE, PSNR, SSIM, FID, LPIPS).
|
|
1682
|
+
|
|
1683
|
+
Parameters
|
|
1684
|
+
----------
|
|
1685
|
+
model : nn.Module
|
|
1686
|
+
The variational autoencoder model (AutoencoderLDM) to train.
|
|
1687
|
+
optimizer : torch.optim.Optimizer
|
|
1688
|
+
Optimizer for training (e.g., Adam).
|
|
1689
|
+
data_loader : torch.utils.data.DataLoader
|
|
1690
|
+
DataLoader for training data.
|
|
1691
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
1692
|
+
DataLoader for validation data (default: None).
|
|
1693
|
+
max_epochs : int, optional
|
|
1694
|
+
Maximum number of training epochs (default: 100).
|
|
1695
|
+
metrics_ : object, optional
|
|
1696
|
+
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
1697
|
+
device : None, optional
|
|
1698
|
+
Device for computation (e.g., 'cuda', 'cpu').
|
|
1699
|
+
store_path : str, optional
|
|
1700
|
+
Path to save model checkpoints (default: 'vlc_model.pth').
|
|
1701
|
+
checkpoint : int, optional
|
|
1702
|
+
Frequency (in epochs) to save model checkpoints (default: 10).
|
|
1703
|
+
kl_warmup_epochs : int, optional
|
|
1704
|
+
Number of epochs for KL loss warmup (default: 10).
|
|
1705
|
+
patience : int, optional
|
|
1706
|
+
Number of epochs to wait for early stopping if validation loss doesn’t improve
|
|
1707
|
+
(default: 10).
|
|
1708
|
+
val_frequency : int, optional
|
|
1709
|
+
Frequency (in epochs) for validation and metric computation (default: 5).
|
|
1710
|
+
use_ddp : bool, optional
|
|
1711
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
1712
|
+
grad_accumulation_steps : int, optional
|
|
1713
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
1714
|
+
log_frequency : int, optional
|
|
1715
|
+
Number of epochs before printing loss.
|
|
1716
|
+
"""
|
|
1717
|
+
|
|
1718
|
+
def __init__(
|
|
1719
|
+
self,
|
|
1720
|
+
model: torch.nn.Module,
|
|
1721
|
+
optimizer: torch.optim.Optimizer,
|
|
1722
|
+
data_loader: torch.utils.data.DataLoader,
|
|
1723
|
+
val_loader: Optional[torch.utils.data.DataLoader] = None,
|
|
1724
|
+
max_epochs: int = 100,
|
|
1725
|
+
metrics_: Optional[Any] = None,
|
|
1726
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
1727
|
+
store_path: str = "vlc_model",
|
|
1728
|
+
checkpoint: int = 10,
|
|
1729
|
+
kl_warmup_epochs: int = 10,
|
|
1730
|
+
patience: int = 10,
|
|
1731
|
+
val_frequency: int = 5,
|
|
1732
|
+
warmup_epochs: int = 100,
|
|
1733
|
+
use_ddp: bool = False,
|
|
1734
|
+
grad_accumulation_steps: int = 1,
|
|
1735
|
+
log_frequency: int = 1,
|
|
1736
|
+
use_compilation: bool = False
|
|
1737
|
+
) -> None:
|
|
1738
|
+
super().__init__()
|
|
1739
|
+
|
|
1740
|
+
# initialize DDP settings first
|
|
1741
|
+
self.use_ddp = use_ddp
|
|
1742
|
+
self.grad_accumulation_steps = grad_accumulation_steps
|
|
1743
|
+
if device is None:
|
|
1744
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
1745
|
+
elif isinstance(device, str):
|
|
1746
|
+
self.device = torch.device(device)
|
|
1747
|
+
else:
|
|
1748
|
+
self.device = device
|
|
1749
|
+
|
|
1750
|
+
# setup distributed training if enabled
|
|
1751
|
+
if self.use_ddp:
|
|
1752
|
+
self._setup_ddp()
|
|
1753
|
+
else:
|
|
1754
|
+
self._setup_single_gpu()
|
|
1755
|
+
|
|
1756
|
+
self.model = model.to(self.device)
|
|
1757
|
+
self.optimizer = optimizer
|
|
1758
|
+
self.data_loader = data_loader
|
|
1759
|
+
self.val_loader = val_loader
|
|
1760
|
+
self.max_epochs = max_epochs
|
|
1761
|
+
self.metrics_ = metrics_
|
|
1762
|
+
self.store_path = store_path
|
|
1763
|
+
self.checkpoint = checkpoint
|
|
1764
|
+
self.kl_warmup_epochs = kl_warmup_epochs
|
|
1765
|
+
self.patience = patience
|
|
1766
|
+
self.use_compilation = use_compilation
|
|
1767
|
+
|
|
1768
|
+
# Learning rate scheduling
|
|
1769
|
+
self.scheduler = ReduceLROnPlateau(
|
|
1770
|
+
self.optimizer,
|
|
1771
|
+
patience=self.patience,
|
|
1772
|
+
factor=0.5
|
|
1773
|
+
)
|
|
1774
|
+
self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
1775
|
+
self.val_frequency = val_frequency
|
|
1776
|
+
self.log_frequency = log_frequency
|
|
1777
|
+
|
|
1778
|
+
def _setup_ddp(self) -> None:
|
|
1779
|
+
"""Setup Distributed Data Parallel training configuration.
|
|
1780
|
+
|
|
1781
|
+
Initializes process group, determines rank information, and sets up
|
|
1782
|
+
CUDA device for the current process.
|
|
1783
|
+
"""
|
|
1784
|
+
# check if DDP environment variables are set
|
|
1785
|
+
if "RANK" not in os.environ:
|
|
1786
|
+
raise ValueError("DDP enabled but RANK environment variable not set")
|
|
1787
|
+
if "LOCAL_RANK" not in os.environ:
|
|
1788
|
+
raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
|
|
1789
|
+
if "WORLD_SIZE" not in os.environ:
|
|
1790
|
+
raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
|
|
1791
|
+
|
|
1792
|
+
# ensure CUDA is available for DDP
|
|
1793
|
+
if not torch.cuda.is_available():
|
|
1794
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
1795
|
+
|
|
1796
|
+
# initialize process group only if not already initialized
|
|
1797
|
+
if not torch.distributed.is_initialized():
|
|
1798
|
+
init_process_group(backend="nccl")
|
|
1799
|
+
|
|
1800
|
+
# get rank information
|
|
1801
|
+
self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
|
|
1802
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
|
|
1803
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
|
|
1804
|
+
|
|
1805
|
+
# set device and make it current
|
|
1806
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
1807
|
+
torch.cuda.set_device(self.device)
|
|
1808
|
+
|
|
1809
|
+
# master process handles logging, checkpointing, etc.
|
|
1810
|
+
self.master_process = self.ddp_rank == 0
|
|
1811
|
+
|
|
1812
|
+
if self.master_process:
|
|
1813
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
1814
|
+
|
|
1815
|
+
def _setup_single_gpu(self) -> None:
|
|
1816
|
+
"""Setup single GPU or CPU training configuration."""
|
|
1817
|
+
self.ddp_rank = 0
|
|
1818
|
+
self.ddp_local_rank = 0
|
|
1819
|
+
self.ddp_world_size = 1
|
|
1820
|
+
self.master_process = True
|
|
1821
|
+
|
|
1822
|
+
|
|
1823
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a training checkpoint to resume training.

    Restores the autoencoder weights (`self.model`) and the optimizer state from a
    checkpoint dictionary containing 'model_state_dict' and 'optimizer_state_dict'.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch stored in the checkpoint (-1 if absent).
    loss : float
        The loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    KeyError
        If the checkpoint lacks 'model_state_dict' or 'optimizer_state_dict'.

    **Notes**

    - A 'module.' key prefix (added by DistributedDataParallel) is added or stripped
      depending on whether `self.model` is currently DDP-wrapped, so a checkpoint can be
      loaded both before and after the model has been wrapped.
    - An optimizer state that does not fit the current optimizer only triggers a warning.
    """
    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

    if 'model_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict' key")

    state_dict = checkpoint['model_state_dict']
    prefix = 'module.'
    # Decide from the model itself: wrapping happens lazily, so `use_ddp` alone does not
    # tell whether the parameter names currently carry the DDP prefix.
    model_is_wrapped = isinstance(self.model, torch.nn.parallel.DistributedDataParallel)
    checkpoint_is_wrapped = bool(state_dict) and all(key.startswith(prefix) for key in state_dict)
    if model_is_wrapped and not checkpoint_is_wrapped:
        state_dict = {f'{prefix}{key}': value for key, value in state_dict.items()}
    elif checkpoint_is_wrapped and not model_is_wrapped:
        # Strip only the leading prefix; inner 'module.' segments belong to real submodules.
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    self.model.load_state_dict(state_dict)

    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    # The trainer owns a single model; make sure it sits on the training device.
    self.model.to(self.device)

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")

    return epoch, loss
|
|
1879
|
+
|
|
1880
|
+
@staticmethod
|
|
1881
|
+
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
|
|
1882
|
+
"""Creates a learning rate scheduler for warmup.
|
|
1883
|
+
|
|
1884
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
1885
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
1886
|
+
|
|
1887
|
+
Parameters
|
|
1888
|
+
----------
|
|
1889
|
+
optimizer : torch.optim.Optimizer
|
|
1890
|
+
Optimizer to apply the scheduler to.
|
|
1891
|
+
warmup_epochs : int
|
|
1892
|
+
Number of epochs for the warmup phase.
|
|
1893
|
+
|
|
1894
|
+
Returns
|
|
1895
|
+
-------
|
|
1896
|
+
torch.optim.lr_scheduler.LambdaLR
|
|
1897
|
+
Learning rate scheduler for warmup.
|
|
1898
|
+
"""
|
|
1899
|
+
|
|
1900
|
+
def lr_lambda(epoch):
|
|
1901
|
+
if epoch < warmup_epochs:
|
|
1902
|
+
return epoch / warmup_epochs
|
|
1903
|
+
return 1.0
|
|
1904
|
+
|
|
1905
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
1906
|
+
|
|
1907
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
1908
|
+
"""Wrap models with DistributedDataParallel for multi-GPU training."""
|
|
1909
|
+
if self.use_ddp:
|
|
1910
|
+
# wrap noise predictor with DDP
|
|
1911
|
+
self.noise_predictor = DDP(
|
|
1912
|
+
self.noise_predictor,
|
|
1913
|
+
device_ids=[self.ddp_local_rank],
|
|
1914
|
+
find_unused_parameters=True
|
|
1915
|
+
)
|
|
1916
|
+
|
|
1917
|
+
# wrap conditional model with DDP if it exists
|
|
1918
|
+
if self.conditional_model is not None:
|
|
1919
|
+
self.conditional_model = DDP(
|
|
1920
|
+
self.conditional_model,
|
|
1921
|
+
device_ids=[self.ddp_local_rank],
|
|
1922
|
+
find_unused_parameters=True
|
|
1923
|
+
)
|
|
1924
|
+
|
|
1925
|
+
|
|
1926
|
+
def forward(self) -> Tuple[List[float], float]:
    """Trains the AutoencoderLDM model with mixed precision and evaluation metrics.

    Performs training with reconstruction and regularization losses, KL warmup,
    gradient accumulation, gradient clipping, and learning rate scheduling. Saves a
    checkpoint whenever the monitored loss improves on a validation epoch and
    supports early stopping. Under DDP the early-stopping decision taken by the
    master process is shared with every rank so all processes leave the loop
    together.

    Returns
    -------
    train_losses : list
        List of mean training losses per epoch.
    best_val_loss : float
        Best monitored loss seen on the master process (validation loss, or
        training loss when no validation ran); ``float("inf")`` on other ranks.
    """
    # compile models for optimization (if supported)
    if self.use_compilation:
        try:
            self.model = torch.compile(self.model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # custom attributes (use_vq, beta, current_beta) live on the underlying
    # module; a DDP wrapper does not forward them
    core_model = self.model
    if self.use_ddp and hasattr(core_model, "module"):
        core_model = core_model.module

    # 'cuda:1' or a torch.device must still select CUDA autocast
    autocast_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        if core_model.use_vq:
            beta = 1.0  # no warmup for VQ
        else:
            # a non-positive warmup length means "no warmup" instead of a
            # ZeroDivisionError
            if self.kl_warmup_epochs > 0:
                warmup_fraction = min(1.0, epoch / self.kl_warmup_epochs)
            else:
                warmup_fraction = 1.0
            beta = warmup_fraction * core_model.beta
            core_model.current_beta = beta

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (x, _) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                _, loss, _, _ = self.model(x)
                # scale the loss so accumulated gradients average over the window
                loss = loss / self.grad_accumulation_steps

            # backward pass
            scaler.scale(loss).backward()

            # gradient accumulation and optimizer step
            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients (they must be unscaled first)
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                # optimizer step
                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # update learning rate (warmup scheduler)
                self.warmup_lr_scheduler.step()

            # record loss (unscaled)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # compute mean training loss
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

        # all-reduce loss across processes for DDP
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        train_losses.append(mean_train_loss)

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        is_val_epoch = (epoch + 1) % self.val_frequency == 0

        # validation step
        if self.val_loader is not None and is_val_epoch:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # save checkpoint and early stopping (decided by the master process)
        stop_training = False
        if self.master_process:
            if current_best < best_val_loss and is_val_epoch:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop_training = True

        # share the decision: if only the master left the loop, the other
        # ranks would block forever in their next collective call
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
2058
|
+
|
|
2059
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
|
|
2060
|
+
"""Save model checkpoint (only called by master process).
|
|
2061
|
+
|
|
2062
|
+
Parameters
|
|
2063
|
+
----------
|
|
2064
|
+
epoch : int
|
|
2065
|
+
Current epoch number.
|
|
2066
|
+
loss : float
|
|
2067
|
+
Current loss value.
|
|
2068
|
+
suffix : str, optional
|
|
2069
|
+
Suffix to add to checkpoint filename.
|
|
2070
|
+
"""
|
|
2071
|
+
try:
|
|
2072
|
+
# get state dicts, handling DDP wrapping
|
|
2073
|
+
model_state = (
|
|
2074
|
+
self.model.module.state_dict() if self.use_ddp else self.model.state_dict()
|
|
2075
|
+
)
|
|
2076
|
+
|
|
2077
|
+
checkpoint = {
|
|
2078
|
+
'epoch': epoch,
|
|
2079
|
+
'model_state_dict': model_state,
|
|
2080
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
2081
|
+
'loss': loss,
|
|
2082
|
+
'max_epochs': self.max_epochs,
|
|
2083
|
+
}
|
|
2084
|
+
|
|
2085
|
+
filename = f"ldm_epoch_{epoch}{suffix}.pth"
|
|
2086
|
+
filepath = os.path.join(self.store_path, filename)
|
|
2087
|
+
os.makedirs(self.store_path, exist_ok=True)
|
|
2088
|
+
torch.save(checkpoint, filepath)
|
|
2089
|
+
print(f"Model saved at epoch {epoch}")
|
|
2090
|
+
except Exception as e:
|
|
2091
|
+
print(f"Failed to save model: {e}")
|
|
2092
|
+
|
|
2093
|
+
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Run one pass over the validation loader and aggregate the results.

    The model is switched to eval mode, every batch of ``self.val_loader`` is
    pushed through it without gradients, and the per-batch losses are
    averaged. When ``self.metrics_`` is provided, its ``forward(x, x_hat)``
    result is collected per batch for each metric family that is enabled on
    the metrics object. Under DDP only the validation loss is averaged across
    processes. The model is put back into train mode before returning.

    Returns
    -------
    val_loss : float
        Mean validation loss.
    fid : float
        Mean FID score, or ``float('inf')`` if not computed.
    mse : float or None
        Mean MSE, or None if not computed.
    psnr : float or None
        Mean PSNR, or None if not computed.
    ssim : float or None
        Mean SSIM, or None if not computed.
    lpips_score : float or None
        Mean LPIPS score, or None if not computed.
    """

    def _mean_or(values, fallback):
        # average the collected values, or report the fallback when empty
        return torch.tensor(values).mean().item() if values else fallback

    self.model.eval()

    batch_losses = []
    collected = {"fid": [], "mse": [], "psnr": [], "ssim": [], "lpips": []}

    with torch.no_grad():
        for batch, _ in self.val_loader:
            batch = batch.to(self.device)
            recon, batch_loss, _reg, _latent = self.model(batch)
            batch_losses.append(batch_loss.item())

            # nothing more to do for this batch without a metrics object
            if self.metrics_ is None:
                continue

            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(batch, recon)

            if getattr(self.metrics_, 'fid', False):
                collected["fid"].append(fid)
            if getattr(self.metrics_, 'metrics', False):
                collected["mse"].append(mse)
                collected["psnr"].append(psnr)
                collected["ssim"].append(ssim)
            if getattr(self.metrics_, 'lpips', False):
                collected["lpips"].append(lpips_score)

    # mean validation loss over all batches
    val_loss = torch.tensor(batch_losses).mean().item()

    # average the validation loss across processes for DDP
    if self.use_ddp:
        synced = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(synced, op=dist.ReduceOp.AVG)
        val_loss = synced.item()

    fid_avg = _mean_or(collected["fid"], float('inf'))
    mse_avg = _mean_or(collected["mse"], None)
    psnr_avg = _mean_or(collected["psnr"], None)
    ssim_avg = _mean_or(collected["ssim"], None)
    lpips_avg = _mean_or(collected["lpips"], None)

    self.model.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|