TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/unclip.py
ADDED
|
@@ -0,0 +1,4170 @@
|
|
|
1
|
+
"""
|
|
2
|
+
**UnCLIP Diffusion Model**
|
|
3
|
+
|
|
4
|
+
This module provides a comprehensive implementation of the UnCLIP diffusion model,
|
|
5
|
+
as described in Ramesh et al. (2022, "Hierarchical Text-Conditional Image Generation with CLIP Latents").
|
|
6
|
+
It integrates CLIP embeddings with diffusion processes for high-quality image generation conditioned on text prompts or image embeddings.
|
|
7
|
+
The module supports training, sampling, and upsampling processes, leveraging components from CLIP, GLIDE, and DDIM,
|
|
8
|
+
with classifier-free guidance and text dropout for robust generation.
|
|
9
|
+
|
|
10
|
+
**Components**
|
|
11
|
+
|
|
12
|
+
- **VarianceSchedulerUnCLIP**: Manages noise schedules with support for linear, sigmoid, quadratic, constant, inverse_time,
|
|
13
|
+
and cosine beta schedules, including subsampled (tau) schedules for efficient sampling.
|
|
14
|
+
- **ForwardUnCLIP**: Forward diffusion process to add noise to image or latent embeddings.
|
|
15
|
+
- **ReverseUnCLIP**: Reverse diffusion process for denoising, supporting noise or clean image predictions with subsampled steps.
|
|
16
|
+
- **CLIPEncoder**: Encodes images or text into embeddings using a pre-trained CLIP model.
|
|
17
|
+
- **UnClipDecoder**: Generates low-resolution images (64x64) from CLIP embeddings, incorporating GLIDE text encoding and classifier-free guidance.
|
|
18
|
+
- **UnCLIPTransformerPrior**: Transformer-based prior to predict clean image embeddings from noisy embeddings and text conditions.
|
|
19
|
+
- **CLIPContextProjection**: Projects CLIP image embeddings into context tokens for the decoder.
|
|
20
|
+
- **CLIPEmbeddingProjection**: Reduces and reconstructs embedding dimensionality for efficient processing.
|
|
21
|
+
- **TrainUnClipDecoder**: Orchestrates training of the decoder with mixed precision, gradient accumulation, and DDP support.
|
|
22
|
+
- **SampleUnCLIP**: Generates images from text prompts or noise, scaling from 64x64 to 256x256 or 1024x1024 with upsamplers.
|
|
23
|
+
- **UpsamplerUnCLIP**: U-Net-based upsampler for scaling images (64x64 to 256x256 or 256x256 to 1024x1024), conditioned on low-resolution inputs.
|
|
24
|
+
- **TrainUpsamplerUnCLIP**: Trains the upsampler with noise prediction, low-resolution conditioning, and optional image corruption (Gaussian blur or BSR degradation).
|
|
25
|
+
|
|
26
|
+
**Notes**
|
|
27
|
+
|
|
28
|
+
- The model uses a subsampled time step schedule (tau) for faster sampling, controlled by the `tau_num_steps` parameter in VarianceSchedulerUnCLIP.
|
|
29
|
+
- Classifier-free guidance and text dropout enhance generation quality, with tunable parameters `classifier_free_prop` and `drop_caption`.
|
|
30
|
+
- Upsampling stages use corrupted low-resolution inputs (Gaussian blur for 64x64→256x256, BSR degradation for 256x256→1024x1024) to improve robustness.
|
|
31
|
+
- Supports distributed training with DDP, mixed precision via autocast, and learning rate scheduling with warmup and plateau reduction.
|
|
32
|
+
|
|
33
|
+
**References**
|
|
34
|
+
|
|
35
|
+
- Ramesh, Aditya, et al. "Hierarchical Text-Conditional Image Generation with CLIP Latents." arXiv preprint arXiv:2204.06125 (2022).
|
|
36
|
+
- Radford, Alec, et al. "Learning Transferable Visual Models From Natural Language Supervision." arXiv preprint arXiv:2103.00020 (2021).
|
|
37
|
+
- Nichol, Alexander, et al. "GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models." arXiv preprint arXiv:2112.10741 (2021).
|
|
38
|
+
- Song, Jiaming, et al. "Denoising Diffusion Implicit Models." arXiv preprint arXiv:2010.02502 (2020).
|
|
39
|
+
|
|
40
|
+
-------------------------------------------------------------------------------
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
import torch
|
|
44
|
+
import torch.nn as nn
|
|
45
|
+
import torch.nn.functional as F
|
|
46
|
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
|
47
|
+
import torch.distributed as dist
|
|
48
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
49
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
50
|
+
import torchvision
|
|
51
|
+
from PIL import Image
|
|
52
|
+
from transformers import BertTokenizer, CLIPProcessor, CLIPModel
|
|
53
|
+
from typing import Optional, List, Tuple, Union, Callable, Any, Self
|
|
54
|
+
from tqdm import tqdm
|
|
55
|
+
import os
|
|
56
|
+
import warnings
|
|
57
|
+
import random
|
|
58
|
+
import math
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
###==================================================================================================================###
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class VarianceSchedulerUnCLIP(nn.Module):
    """Manages noise schedule parameters for UnCLIP diffusion models.

    Handles beta values, derived noise schedule quantities, and a subsampled time step schedule
    (tau schedule) for UnCLIP diffusion processes. Supports trainable or fixed beta schedules
    and multiple scheduling methods, including linear, sigmoid, quadratic, constant, inverse_time,
    and cosine schedules.

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the reverse process (default: None, treated as 0, i.e. deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for sampling (default: 100). Must be positive; a value
        larger than `num_steps` yields repeated tau indices and emits a warning.
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False). Trainable betas are stored as
        logits in the parameter `beta_raw` and mapped into (beta_start, beta_end) with a sigmoid;
        fixed betas and their derived quantities are stored as buffers.
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

    Raises
    ------
    ValueError
        If `0 < beta_start < beta_end < 1` does not hold, if `num_steps` or `tau_num_steps`
        is not positive, or if `beta_method` is unknown.
    """

    # Clamp used when inverting the sigmoid reparameterisation of trainable betas. Every schedule
    # is clamped to [beta_start, beta_end], so normalised values can be exactly 0 or 1, whose
    # logits are -inf/+inf; non-finite parameters can poison optimisers (e.g. weight decay on an
    # infinite weight produces NaN).
    _LOGIT_EPS = 1e-6

    def __init__(
        self,
        eta: Optional[float] = None,
        num_steps: int = 1000,
        tau_num_steps: int = 100,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        trainable_beta: bool = False,
        beta_method: str = "linear"
    ) -> None:
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")
        if tau_num_steps > num_steps:
            # still allowed for backward compatibility, but the subsampled schedule repeats steps
            warnings.warn(
                f"tau_num_steps ({tau_num_steps}) exceeds num_steps ({num_steps}); "
                "the tau schedule will contain repeated time steps"
            )

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # store betas as logits of their position inside [beta_start, beta_end]; `eps` keeps
            # the endpoints finite (see _LOGIT_EPS)
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=self._LOGIT_EPS))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # evenly spaced indices into the full schedule used for subsampled sampling
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable.

        Returns the beta values, using sigmoid reparameterization for trainable betas
        or directly accessing the stored buffer for fixed betas.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,).
        """
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the noise schedule using the chosen method,
        ensuring values are clamped within the specified range. Supports linear, sigmoid,
        quadratic, constant, inverse_time, and cosine schedules.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

        Returns
        -------
        beta : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # a single-step schedule has no range to rescale (0/0 would give NaN);
                # start at beta_min, as the "linear" schedule does
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        elif method == "cosine":
            s = 0.008
            steps = num_steps + 1
            x = torch.linspace(0, num_steps, steps)
            alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            beta = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        else:
            raise ValueError(f"Unknown beta_method: {method}")
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes the subsampled (tau) noise schedule for UnCLIP.

        Returns the noise schedule parameters for the subsampled time steps used in
        UnCLIP sampling, based on the `tau_indices`.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values for subsampled steps, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values for subsampled steps, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative product of alphas for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            # trainable betas have no cached buffers; derive everything from the current betas
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule()
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters dynamically from betas.

        Calculates the derived noise schedule parameters (alphas, alpha_cumprod, etc.)
        from the current beta values for the UnCLIP diffusion process.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, returns parameters only for specified time steps.
            If None, returns parameters for all time steps.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,) or (len(time_steps),).
        alphas : torch.Tensor
            1 - betas, shape (num_steps,) or (len(time_steps),).
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas, shape (num_steps,) or (len(time_steps),).
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod, shape (num_steps,) or (len(time_steps),).
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod), shape (num_steps,) or (len(time_steps),).
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)
        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
|
|
262
|
+
|
|
263
|
+
###==================================================================================================================###
|
|
264
|
+
|
|
265
|
+
class ForwardUnCLIP(nn.Module):
    """Forward diffusion process for UnCLIP diffusion models.

    Applies Gaussian noise to input data according to the UnCLIP forward diffusion process
    at specified time steps, using cumulative noise schedule parameters from the variance
    scheduler. Inputs of any rank are supported (e.g. 2D embeddings or 4D images); the
    first dimension is always the batch.

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP) containing the noise
        schedule parameters.
    """
    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Computes ``xt = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise``
        for each sample in the batch.

        Parameters
        ----------
        `x0` : torch.Tensor
            Input data tensor with the batch as the first dimension, e.g.
            (batch_size, embedding_dim) or (batch_size, channels, height, width).
        `noise` : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        `time_steps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,),
            where each value is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt : torch.Tensor
            Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` lies outside [0, num_steps - 1].
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.variance_scheduler.trainable_beta:
            # trainable schedules are recomputed so gradients can reach the betas
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.compute_schedule(
                time_steps
            )
        else:
            sqrt_alpha_cumprod_t = self.variance_scheduler.sqrt_alpha_cumprod[time_steps]
            sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.sqrt_one_minus_alpha_cumprod[time_steps]

        # reshape per-sample coefficients to (batch, 1, ..., 1) so they broadcast over every
        # non-batch dimension of x0, whatever its rank
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.to(x0.device).view(coeff_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.to(x0.device).view(coeff_shape)

        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
        return xt
|
|
330
|
+
|
|
331
|
+
###==================================================================================================================###
|
|
332
|
+
|
|
333
|
+
class ReverseUnCLIP(nn.Module):
    """Reverse diffusion process for UnCLIP diffusion models.

    Denoises a noisy input `xt` with a DDIM-style update over the subsampled (tau) time step
    schedule, using either a predicted noise component or a predicted clean sample. Inputs of
    any rank are supported (e.g. 2D embeddings or 4D images); the first dimension is the batch.

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP) containing the noise
        schedule parameters; must provide `tau_num_steps`, `eta` and `get_tau_schedule()`.
    `prediction_type` : str, default "noise"
        Type of prediction the model makes. Either "noise" (predicts noise like DDIM) or
        "x0" (predicts clean image like UnCLIP prior).

    Raises
    ------
    ValueError
        If `prediction_type` is not "noise" or "x0".
    """

    def __init__(self, variance_scheduler: torch.nn.Module, prediction_type: str = "noise"):
        super().__init__()
        self.variance_scheduler = variance_scheduler
        if prediction_type not in ["noise", "x0"]:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type

    @staticmethod
    def _per_sample(values: torch.Tensor, indices: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Gather `values[indices]` onto `like`'s device, shaped (batch, 1, ..., 1) to broadcast over `like`."""
        shape = (-1,) + (1,) * (like.dim() - 1)
        return values[indices].to(like.device).view(shape)

    def forward(
        self,
        xt: torch.Tensor,
        model_prediction: torch.Tensor,
        time_steps: torch.Tensor,
        prev_time_steps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies one reverse (DDIM) step to the noisy input.

        Denoises `xt` at tau index `time_steps` to produce `xt_prev` at tau index
        `prev_time_steps`:

            xt_prev = sqrt(a_prev) * x0 + sigma * z + sqrt(1 - a_prev - sigma^2) * eps
            sigma   = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)

        where `a` denotes alpha_cumprod, `eps` is the (predicted or implied) noise and
        `z ~ N(0, I)` (Song et al., 2020). With eta = 0 the step is deterministic.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input tensor at time step `t`, batch first, e.g. (batch_size, embedding_dim)
            or (batch_size, channels, height, width).
        `model_prediction` : torch.Tensor
            Model prediction tensor, same shape as `xt`. Either predicted noise or predicted
            clean sample depending on `prediction_type`.
        `time_steps` : torch.Tensor
            Tensor of tau indices (long), shape (batch_size,), each in
            [0, variance_scheduler.tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Tensor of previous tau indices (long), shape (batch_size,), each in
            [0, variance_scheduler.tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Denoised tensor at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` or `prev_time_steps` is out of range.
        """
        tau_num_steps = self.variance_scheduler.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.variance_scheduler.get_tau_schedule()

        sqrt_a_t = self._per_sample(tau_sqrt_alpha_cumprod, time_steps, xt)
        sqrt_one_minus_a_t = self._per_sample(tau_sqrt_one_minus_alpha_cumprod, time_steps, xt)
        sqrt_a_prev = self._per_sample(tau_sqrt_alpha_cumprod, prev_time_steps, xt)
        sqrt_one_minus_a_prev = self._per_sample(tau_sqrt_one_minus_alpha_cumprod, prev_time_steps, xt)

        eta = self.variance_scheduler.eta

        if self.prediction_type == "noise":
            predicted_noise = model_prediction
            x0 = (xt - sqrt_one_minus_a_t * predicted_noise) / sqrt_a_t
        else:
            # model predicts the clean sample; recover the noise it implies
            x0 = model_prediction
            predicted_noise = (xt - sqrt_a_t * x0) / torch.clamp(sqrt_one_minus_a_t, min=1e-8)

        # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev); the inner
        # clamp guards against tiny negative values from rounding when t == prev
        alpha_ratio = (sqrt_a_t / sqrt_a_prev) ** 2
        noise_coeff = eta * (sqrt_one_minus_a_prev / torch.clamp(sqrt_one_minus_a_t, min=1e-8)) * torch.sqrt(
            torch.clamp(1 - alpha_ratio, min=1e-12)
        )
        direction_coeff = torch.clamp(sqrt_one_minus_a_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        # randn_like is always drawn (even when eta == 0) so the RNG stream matches across eta values
        xt_prev = sqrt_a_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0

    def set_prediction_type(self, prediction_type: str):
        """Change the prediction type after initialization.

        Parameters
        ----------
        prediction_type : str
            Type of prediction the model makes. Either "noise" or "x0".

        Raises
        ------
        ValueError
            If `prediction_type` is not "noise" or "x0".
        """
        if prediction_type not in ["noise", "x0"]:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type
|
|
447
|
+
|
|
448
|
+
###==================================================================================================================###
|
|
449
|
+
|
|
450
|
+
class CLIPEncoder(nn.Module):
|
|
451
|
+
"""Encodes images or text using a pre-trained CLIP model.
|
|
452
|
+
|
|
453
|
+
Loads a CLIP model and processor from the transformers library, providing methods to
|
|
454
|
+
encode images or text into embeddings and compute similarity scores between them.
|
|
455
|
+
|
|
456
|
+
Parameters
|
|
457
|
+
----------
|
|
458
|
+
`model_name` : str, optional
|
|
459
|
+
Name of the CLIP model to load (default: 'openai/clip-vit-base-patch32').
|
|
460
|
+
`device` : str, optional
|
|
461
|
+
Device to run the model on (default: 'cuda' if available, else 'cpu').
|
|
462
|
+
`use_fast` : bool, optional
|
|
463
|
+
Whether to use the fast image processor (torchvision-based) (default: False).
|
|
464
|
+
"""
|
|
465
|
+
def __init__(
    self,
    model_name: str = "openai/clip-vit-base-patch32",
    device: Optional[str] = None,
    use_fast: bool = False,
) -> None:
    """Load a pre-trained CLIP model and processor and move the model to `device`.

    Parameters
    ----------
    `model_name` : str, optional
        Name of the CLIP model to load (default: 'openai/clip-vit-base-patch32').
    `device` : str, optional
        Device to run the model on (default: 'cuda' if available, else 'cpu').
    `use_fast` : bool, optional
        Forwarded to the processor's `from_pretrained` as `use_fast` (default: False).

    Raises
    ------
    RuntimeError
        If loading the model or processor, or moving the model to `device`, fails. The
        original exception is attached as the cause.
    """
    super().__init__()

    # set model name and device
    self.model_name = model_name
    self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # load CLIP model and processor
        self.model = CLIPModel.from_pretrained(self.model_name)
        self.processor = CLIPProcessor.from_pretrained(self.model_name, use_fast=use_fast)
        self.model = self.model.to(self.device)
    except Exception as e:
        # chain explicitly so the underlying loading error stays visible as __cause__
        raise RuntimeError(f"Failed to load CLIP model or processor for {self.model_name}: {e}") from e

    # set model to evaluation mode by default
    self.model.eval()
|
|
487
|
+
|
|
488
|
+
def forward(
    self,
    data: Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]],
    data_type: str,
    normalize: bool = True
) -> torch.Tensor:
    """Encode images or text into CLIP embeddings.

    Dispatches to the image or text encoder according to `data_type`, without tracking
    gradients, and optionally L2-normalizes the result along the last dimension.

    Parameters
    ----------
    `data` : Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]]
        What to encode:
        - torch.Tensor: preprocessed image tensor (batch_size, channels, height, width)
          or a text tensor accepted by the text encoder.
        - str or List[str]: one or more texts.
        - PIL.Image.Image or List[PIL.Image.Image]: one or more images.
    `data_type` : str
        Either 'img' or 'text'.
    `normalize` : bool, optional
        L2-normalize the embeddings (default: True).

    Returns
    -------
    torch.Tensor
        Embeddings, shape (batch_size, embedding_dim).

    Raises
    ------
    ValueError
        If `data_type` is neither 'img' nor 'text'.
    """
    if data_type not in ("img", "text"):
        raise ValueError(f"Invalid data_type: {data_type}. Must be 'img' or 'text'.")

    encode = self._encode_images if data_type == "img" else self._encode_texts

    with torch.no_grad():
        embeddings = encode(data)
        return F.normalize(embeddings, p=2, dim=-1) if normalize else embeddings
|
|
530
|
+
|
|
531
|
+
def _encode_images(self, data: Union[torch.Tensor, Image.Image, List[Image.Image]]) -> torch.Tensor:
    """Turn images into CLIP image embeddings.

    Tensors are used directly as pixel values (a single 3D image gains a batch
    dimension); PIL images, alone or in a list, are run through the CLIP processor first.

    Parameters
    ----------
    `data` : Union[torch.Tensor, Image.Image, List[Image.Image]]
        Pixel tensor, single PIL image, or list of PIL images.

    Returns
    -------
    torch.Tensor
        Image embeddings, shape (batch_size, embedding_dim).

    Raises
    ------
    ValueError
        If `data` is not a tensor, PIL image, or list.
    """
    if isinstance(data, torch.Tensor):
        pixels = data.unsqueeze(0) if data.dim() == 3 else data
        inputs = {"pixel_values": pixels.to(self.device)}
    elif isinstance(data, (Image.Image, list)):
        images = [data] if isinstance(data, Image.Image) else data
        processed = self.processor(images=images, return_tensors="pt", padding=True)
        inputs = {name: tensor.to(self.device) for name, tensor in processed.items()}
    else:
        raise ValueError(f"Invalid image data type: {type(data)}. Expected torch.Tensor, PIL.Image.Image, or List[PIL.Image.Image].")
    return self.model.get_image_features(**inputs)
|
|
558
|
+
|
|
559
|
+
def _encode_texts(self, data: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
    """Encodes texts into embeddings using the CLIP model.

    Accepts raw strings, token-id tensors, or tensors that already hold embeddings.

    Parameters
    ----------
    `data` : Union[str, List[str], torch.Tensor]
        - ``str`` or ``List[str]``: tokenized with ``self.processor`` (padding and
          truncation enabled), then encoded.
        - integer ``torch.Tensor`` of shape (seq_len,) or (batch_size, seq_len):
          treated as token ids and encoded with an all-ones attention mask.
        - floating-point 2-D ``torch.Tensor``: treated as precomputed text
          embeddings and returned unchanged (moved to ``self.device``).

    Returns
    -------
    text_features : torch.Tensor
        Text embeddings, shape (batch_size, embedding_dim).

    Raises
    ------
    ValueError
        If ``data`` is not a str, a list of str, or a tensor.
    """
    if isinstance(data, torch.Tensor):
        data = data.to(self.device)
        # Only a floating-point 2-D tensor can already be an embedding matrix.
        # Integer tensors are token ids and must go through the model even when
        # batched; previously every 2-D tensor was returned as-is, so a batch of
        # token ids came back unencoded.
        if data.dim() == 2 and data.is_floating_point():
            return data
        if data.dim() == 1:
            data = data.unsqueeze(0)
        # raw ids carry no padding information, so every position is attended
        attention_mask = torch.ones_like(data)
        return self.model.get_text_features(input_ids=data, attention_mask=attention_mask)

    if isinstance(data, str):
        data = [data]
    elif isinstance(data, list) and all(isinstance(t, str) for t in data):
        pass
    else:
        raise ValueError(
            f"Invalid text data type: {type(data)}. Expected str, List[str], or torch.Tensor."
        )

    inputs = self.processor(text=data, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(self.device) for k, v in inputs.items()}
    return self.model.get_text_features(**inputs)
|
|
595
|
+
|
|
596
|
+
def compute_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
    """Computes cosine similarity between image and text embeddings.

    Both inputs are L2-normalised along their last dimension and multiplied, so
    entry ``[..., i, j]`` is the cosine similarity of image ``i`` and text ``j``.

    Parameters
    ----------
    `image_features` : torch.Tensor
        Image embeddings, shape (num_images, embedding_dim); leading batch
        dimensions (..., num_images, embedding_dim) are also supported.
    `text_features` : torch.Tensor
        Text embeddings, shape (num_texts, embedding_dim); leading batch
        dimensions (..., num_texts, embedding_dim) are also supported.

    Returns
    -------
    similarity : torch.Tensor
        Cosine similarity scores, shape (num_images, num_texts), with any leading
        batch dimensions preserved. For two 1-D vectors the result is a scalar.
    """
    image_features = F.normalize(image_features, p=2, dim=-1)
    text_features = F.normalize(text_features, p=2, dim=-1)
    # Transpose only the last two axes. ``Tensor.T`` reverses *all* dimensions
    # (and is deprecated for ndim != 2), which breaks batched inputs.
    if text_features.dim() >= 2:
        text_features = text_features.transpose(-2, -1)
    return torch.matmul(image_features, text_features)
|
|
616
|
+
|
|
617
|
+
###==================================================================================================================###
|
|
618
|
+
|
|
619
|
+
class UnClipDecoder(nn.Module):
    """Decoder for UnCLIP diffusion models.

    Combines CLIP image embeddings and text embeddings to guide the denoising process,
    using a noise predictor and diffusion processes. Incorporates classifier-free guidance,
    text caption dropout, and projection of CLIP embeddings into context tokens.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input CLIP embeddings.
    `noise_predictor` : nn.Module
        Model to predict noise during the denoising process. It is called as
        ``noise_predictor(noisy_images, t, context, clip_image_embedding)``.
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise. It must expose
        ``variance_scheduler.num_steps`` and is called as ``forward_diffusion(images, noise, t)``.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising.
    `glide_text_encoder` : nn.Module, optional
        GLIDE text encoder called as ``glide_text_encoder(input_ids, attention_mask)``.
        Default None, which disables text conditioning.
    `bert_tokenizer` : BertTokenizer, optional
        Tokenizer for processing text prompts. Default None, which loads "bert-base-uncased".
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `image_output_range` : Tuple[float, float], optional
        Range for clamping output images (default: (-1.0, 1.0)). It is stored on the
        instance but not read by the methods of this class.
    `normalize_clip_embeddings` : bool, optional
        Whether to normalize outputs (default: True). It is stored on the instance but
        not read by the methods of this class.
    `classifier_free_prop` : float, optional
        Threshold below which image embeddings are zeroed for classifier-free guidance
        (default: 0.1, per paper).
    `drop_caption` : float, optional
        Threshold below which the text caption is dropped (default: 0.5, per paper).
    `max_token_length` : int, optional
        Maximum length for tokenized prompts (default: 77).

    Raises
    ------
    ValueError
        If no tokenizer is given and the default one cannot be loaded.
    """
    def __init__(
        self,
        clip_embedding_dim: int,
        noise_predictor: nn.Module,
        forward_diffusion: nn.Module,
        reverse_diffusion: nn.Module,
        glide_text_encoder: torch.nn.Module = None,  # GLIDE text encoder
        bert_tokenizer: Optional[BertTokenizer] = None,
        device: Optional[Union[str, torch.device]] = None,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_clip_embeddings: bool = True,
        classifier_free_prop: float = 0.1,  # paper specifies 10%
        drop_caption: float = 0.5,  # paper specifies 50%
        max_token_length: int = 77  # max_token_length for tokenization
    ) -> None:
        super().__init__()

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.clip_embedding_dim = clip_embedding_dim

        # core models
        self.noise_predictor = noise_predictor.to(self.device)
        self.forward_diffusion = forward_diffusion.to(self.device)
        self.reverse_diffusion = reverse_diffusion.to(self.device)
        # Test with `is not None` rather than truthiness: container modules such as
        # nn.Sequential define __len__, so an empty one would be treated as "no encoder".
        self.glide_text_encoder = glide_text_encoder.to(self.device) if glide_text_encoder is not None else None

        # paper: "projecting CLIP embeddings into four extra tokens of context"
        self.clip_decoder_projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            num_tokens=4
        ).to(self.device)
        self.clip_time_projection = nn.Linear(self.clip_embedding_dim, self.clip_embedding_dim).to(self.device)

        # training parameters
        self.image_output_range = image_output_range
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.classifier_free_prop = classifier_free_prop
        self.drop_caption = drop_caption
        self.max_token_length = max_token_length

        # Initialize the tokenizer. A caller-supplied tokenizer used to be ignored
        # (self.bert_tokenizer was never set), which made text encoding fail with
        # AttributeError. Both paths now store it.
        if bert_tokenizer is None:
            try:
                bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            except Exception as e:
                raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
        self.bert_tokenizer = bert_tokenizer

    def forward(
        self,
        image_embeddings: torch.Tensor,
        text_embeddings: torch.Tensor,
        images: torch.Tensor,
        texts: torch.Tensor,
        p_classifier_free: float,
        p_text_drop: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Processes embeddings and images to predict noise for training.

        Applies classifier-free guidance and text dropout, projects CLIP image embeddings
        into context tokens, encodes text with GLIDE, and predicts noise for the diffusion process.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim). They are only used to
            decide whether the caption was dropped (a dropped caption becomes None).
        `images` : torch.Tensor
            Input images, shape (batch_size, channels, height, width).
        `texts` : torch.Tensor
            Text prompts for conditional generation, passed to the GLIDE text encoder.
        `p_classifier_free` : float
            Random draw compared against ``classifier_free_prop``; if it is smaller, the
            image embeddings are zeroed.
        `p_text_drop` : float
            Random draw compared against ``drop_caption``; if it is smaller, the caption
            is dropped.

        Returns
        -------
        predicted_noise : torch.Tensor
            Predicted noise tensor, shape (batch_size, channels, height, width).
        noise : torch.Tensor
            Ground truth noise tensor, shape (batch_size, channels, height, width).
        """
        image_embeddings = self._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
        text_embeddings = self._apply_text_dropout(text_embeddings, p_text_drop)
        # project z_i to 4 tokens
        c = self.clip_decoder_projection(image_embeddings)
        # encode text with GLIDE (skipped entirely when the caption was dropped)
        y_encoded = self._encode_text_with_glide(texts if text_embeddings is not None else None)
        # concatenate embeddings
        context = self._concatenate_embeddings(y_encoded, c)
        # sample timestep and noise
        t, noise = self._sample_timestep_and_noise(images.shape[0], images.shape)
        # compute noisy image
        noisy_images = self.forward_diffusion(images, noise, t)
        clip_image_embedding = self.clip_time_projection(image_embeddings)
        predicted_noise = self.noise_predictor(noisy_images, t, context, clip_image_embedding)
        return predicted_noise, noise

    def inference_forward(self, image_embeddings, prompt_embeddings):
        """Placeholder for an inference-time forward pass.

        Not implemented: this method currently does nothing and returns None.
        """
        pass

    def _apply_classifier_free_guidance(self, image_embeddings: torch.Tensor, p_value: float) -> torch.Tensor:
        """Applies classifier-free guidance to image embeddings.

        Zeroes the image embeddings of the whole batch when ``p_value`` is below
        ``self.classifier_free_prop``, as described in the UnCLIP paper.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Random draw compared against ``self.classifier_free_prop``.

        Returns
        -------
        image_embeddings : torch.Tensor
            The input, or a zero tensor of the same shape.
        """
        if p_value < self.classifier_free_prop:
            # set z_i ← 0 {classifier-free guidance}
            image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _apply_text_dropout(self, text_embeddings: torch.Tensor, p_value: float) -> Optional[torch.Tensor]:
        """Applies text caption dropout to text embeddings.

        Drops the text embeddings (returns None) when ``p_value`` is below
        ``self.drop_caption``, as described in the UnCLIP paper.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Random draw compared against ``self.drop_caption``.

        Returns
        -------
        text_embeddings : torch.Tensor or None
            The unchanged input, or None if the caption was dropped.
        """
        if p_value < self.drop_caption:
            # set y ← ∅ {drop text caption}
            return None

        return text_embeddings

    def _encode_text_with_glide(self, texts: Union[List, torch.Tensor]) -> Optional[torch.Tensor]:
        """Encodes text prompts using the GLIDE text encoder.

        Tokenizes the prompts with ``self.bert_tokenizer`` (padded/truncated to
        ``self.max_token_length``) and runs the GLIDE text encoder on them.

        Parameters
        ----------
        `texts` : Union[List, torch.Tensor]
            Text prompts or tensor of text data. Tensor items are converted to strings
            with ``str()`` before tokenization.

        Returns
        -------
        y_encoded : torch.Tensor or None
            Output of the GLIDE text encoder, or None if ``texts`` is None or no
            encoder was configured.
        """
        if texts is None:
            return None

        if self.glide_text_encoder is None:
            return None

        # convert to string list if needed
        if isinstance(texts, torch.Tensor):
            texts = texts.cpu().numpy().tolist()
            texts = [str(item) for item in texts]

        # tokenize
        tokenized = self.bert_tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        ).to(self.device)

        # get embeddings from GLIDE text encoder
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        y_encoded = self.glide_text_encoder(input_ids, attention_mask)

        return y_encoded

    def _concatenate_embeddings(self, y_encoded: Optional[torch.Tensor], c: torch.Tensor) -> torch.Tensor:
        """Concatenates GLIDE text embeddings and context tokens.

        Combines encoded text embeddings (if available) with projected context tokens
        along the sequence dimension, as specified in the UnCLIP paper.

        Parameters
        ----------
        `y_encoded` : torch.Tensor or None
            Encoded text embeddings from GLIDE, shape (batch_size, seq_len, embedding_dim)
            or (batch_size, embedding_dim).
        `c` : torch.Tensor
            Projected context tokens, shape (batch_size, num_tokens, embedding_dim).

        Returns
        -------
        s : torch.Tensor
            Concatenated embeddings, shape (batch_size, seq_len + num_tokens, embedding_dim),
            or ``c`` unchanged when ``y_encoded`` is None.
        """
        if y_encoded is not None:
            # ensure y_encoded has sequence dimension
            if len(y_encoded.shape) == 2:  # [batch_size, embed_dim]
                y_encoded = y_encoded.unsqueeze(1)  # [batch_size, 1, embed_dim]

            # concatenate along the sequence dimension
            s = torch.cat([y_encoded, c], dim=1)  # [batch_size, seq_len + 4, embed_dim]
        else:
            s = c  # [batch_size, 4, embed_dim]

        return s

    def _sample_timestep_and_noise(self, batch_size: int, image_shape: torch.Size) -> Tuple[torch.Tensor, torch.Tensor]:
        """Samples timesteps and noise for the diffusion process.

        Parameters
        ----------
        `batch_size` : int
            Number of samples in the batch.
        `image_shape` : torch.Size
            Shape of the images, typically (batch_size, channels, height, width).

        Returns
        -------
        t : torch.Tensor
            Timestep indices drawn uniformly from ``{0, ..., num_steps - 1}``, shape (batch_size,).
        noise : torch.Tensor
            Standard Gaussian noise with shape ``image_shape``.
        """
        # sample 0-based timestep indices t ~ Uniform{0, ..., num_steps - 1}
        t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
        # sample noise ε ~ N(0, I)
        noise = torch.randn(image_shape, device=self.device)
        return t, noise
|
|
908
|
+
|
|
909
|
+
###==================================================================================================================###
|
|
910
|
+
|
|
911
|
+
class UnCLIPTransformerPrior(nn.Module):
    """Transformer-based prior model for UnCLIP diffusion.

    Predicts clean image embeddings from noisy image embeddings and text embeddings using
    a Transformer architecture, incorporating time embeddings and optional projection
    layers for text and image inputs.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising during training.
    `clip_text_projection` : nn.Module, optional
        Projection module for text embeddings, default None.
    `clip_image_projection` : nn.Module, optional
        Projection module for image embeddings, default None.
    `transformer_embedding_dim` : int, optional
        Dimensionality of embeddings (default: 320).
    `num_layers` : int, optional
        Number of Transformer layers (default: 12).
    `num_attention_heads` : int, optional
        Number of attention heads in each Transformer layer (default: 8).
    `feedforward_dim` : int, optional
        Dimensionality of the feedforward network in Transformer layers (default: 768).
    `max_sequence_length` : int, optional
        Number of learned positional embeddings (default: 2). It must be at least 2,
        because ``forward`` builds a two-token sequence.
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.2).
    """
    def __init__(
        self,
        forward_diffusion: nn.Module,  # will be used during training
        reverse_diffusion: nn.Module,  # will be used during training
        clip_text_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        clip_image_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        transformer_embedding_dim: int = 320,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        feedforward_dim: int = 768,
        max_sequence_length: int = 2,
        dropout_rate: float = 0.2
    ) -> None:
        super().__init__()

        self.forward_diffusion = forward_diffusion
        self.reverse_diffusion = reverse_diffusion
        self.clip_text_projection = clip_text_projection
        self.clip_image_projection = clip_image_projection

        self.transformer_embedding_dim = transformer_embedding_dim
        self.max_sequence_length = max_sequence_length

        # time embedding network
        self.time_embedding_net = nn.Sequential(
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim),
            nn.GELU(),
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim)
        )

        # learned positional embeddings, one row per sequence position
        self.positional_embeddings = nn.Parameter(torch.randn(max_sequence_length, transformer_embedding_dim))

        # transformer layers
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(transformer_embedding_dim, num_attention_heads, feedforward_dim, dropout_rate)
            for _ in range(num_layers)
        ])

        # final output projection
        self.output_projection = nn.Linear(transformer_embedding_dim, transformer_embedding_dim)

    def forward(
        self,
        text_embeddings: torch.Tensor,
        noisy_image_embeddings: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Predicts clean image embeddings from noisy inputs and text embeddings.

        Processes text and noisy image embeddings through a Transformer architecture,
        conditioned on time embeddings, to predict the clean image embeddings.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).
        `noisy_image_embeddings` : torch.Tensor
            Noisy image embeddings, shape (batch_size, embedding_dim).
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).

        Returns
        -------
        predicted_clean_embeddings : torch.Tensor
            Predicted clean image embeddings, shape (batch_size, embedding_dim).

        Raises
        ------
        ValueError
            If there are fewer positional embeddings than sequence positions.
        """
        device = text_embeddings.device
        # create sinusoidal time embeddings
        time_embeddings = self._get_sinusoidal_embeddings(timesteps, self.transformer_embedding_dim, device)
        time_embeddings = self.time_embedding_net(time_embeddings)
        # add time information to image embeddings
        conditioned_image_embeddings = noisy_image_embeddings + time_embeddings
        # create sequence: [text_embeddings, conditioned_image_embeddings]
        sequence = torch.stack([text_embeddings, conditioned_image_embeddings], dim=1)  # [B, 2, D]
        # Add positional embeddings. Only the rows for positions actually present are
        # used; adding the whole table broke broadcasting whenever max_sequence_length != 2.
        seq_len = sequence.size(1)
        if seq_len > self.positional_embeddings.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_sequence_length "
                f"{self.positional_embeddings.size(0)}."
            )
        sequence = sequence + self.positional_embeddings[:seq_len].unsqueeze(0)
        # pass through transformer blocks
        for transformer_block in self.transformer_blocks:
            sequence = transformer_block(sequence)
        # extract predicted clean image embedding (second position in sequence)
        predicted_clean_embeddings = sequence[:, 1, :]  # [B, D]
        # apply final projection
        predicted_clean_embeddings = self.output_projection(predicted_clean_embeddings)

        return predicted_clean_embeddings

    def _get_sinusoidal_embeddings(
        self,
        timesteps: torch.Tensor,
        embedding_dim: int,
        device: Union[torch.device, str]
    ) -> torch.Tensor:
        """Generates sinusoidal positional embeddings for timesteps.

        The first half of the output holds sines and the second half holds cosines. An odd
        ``embedding_dim`` is padded with a trailing zero column.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).
        `embedding_dim` : int
            Dimensionality of the embeddings.
        `device` : Union[torch.device, str]
            Device to place the embeddings on.

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal time embeddings, shape (batch_size, embedding_dim).
        """
        half_dim = embedding_dim // 2
        # Guard the denominator: for embedding_dim of 2 or 3, half_dim - 1 == 0 and the
        # unguarded expression raised ZeroDivisionError.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = timesteps[:, None].float() * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

        # handle odd embedding dimensions
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)

        return emb
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention followed by a feedforward net.

    Each sub-layer is wrapped as ``x = LayerNorm(x + sublayer(x))``. Used by
    UnCLIPTransformerPrior to mix the text and image tokens.

    Parameters
    ----------
    `embedding_dim` : int
        Token width, shared by the input and the output.
    `num_heads` : int
        Number of heads in the multi-head self-attention layer.
    `feedforward_dim` : int
        Hidden width of the position-wise feedforward network.
    `dropout` : float
        Dropout probability inside attention and after each feedforward linear layer.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        feedforward_dim: int,
        dropout: float
    ) -> None:
        super().__init__()

        # batch_first=True -> inputs are laid out as (batch, sequence, embedding)
        self.self_attention = nn.MultiheadAttention(
            embedding_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.attention_norm, self.feedforward_norm = (
            nn.LayerNorm(embedding_dim),
            nn.LayerNorm(embedding_dim),
        )

        ff_layers = [
            nn.Linear(embedding_dim, feedforward_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.feedforward = nn.Sequential(*ff_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one attention + feedforward step over ``x``.

        Parameters
        ----------
        `x` : torch.Tensor
            Input sequence, shape (batch_size, sequence_length, embedding_dim).

        Returns
        -------
        output : torch.Tensor
            Transformed sequence with the same shape as ``x``.
        """
        # the attention weights (second return value) are discarded
        attended = self.self_attention(x, x, x)[0]
        hidden = self.attention_norm(x + attended)

        return self.feedforward_norm(hidden + self.feedforward(hidden))
|
|
1136
|
+
|
|
1137
|
+
###==================================================================================================================###
|
|
1138
|
+
|
|
1139
|
+
class CLIPContextProjection(nn.Module):
    """Expand one CLIP image embedding into several normalised context tokens.

    A single linear layer produces ``num_tokens`` embeddings at once. The output is
    split into tokens and each token is layer-normalised.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Width of the incoming CLIP embedding, which is also the width of every produced token.
    `num_tokens` : int, optional
        How many context tokens to produce (default: 4).
    """
    def __init__(self, clip_embedding_dim, num_tokens=4):
        super().__init__()
        self.clip_embedding_dim = clip_embedding_dim
        self.num_tokens = num_tokens
        # one projection yields all tokens side by side; forward() splits them apart
        self.clip_projection = nn.Linear(clip_embedding_dim, num_tokens * clip_embedding_dim)
        self.clip_embedding_norm = nn.LayerNorm(clip_embedding_dim)

    def forward(self, z_i):
        """Map CLIP image embeddings to context tokens.

        Parameters
        ----------
        `z_i` : torch.Tensor
            CLIP image embedding, shape (batch_size, clip_embedding_dim).

        Returns
        -------
        c : torch.Tensor
            Context tokens, shape (batch_size, num_tokens, clip_embedding_dim).
        """
        token_shape = (z_i.shape[0], self.num_tokens, self.clip_embedding_dim)
        tokens = self.clip_projection(z_i).view(token_shape)
        return self.clip_embedding_norm(tokens)
|
|
1180
|
+
|
|
1181
|
+
###==================================================================================================================###
|
|
1182
|
+
|
|
1183
|
+
class CLIPEmbeddingProjection(nn.Module):
|
|
1184
|
+
"""Projection module for dimensionality reduction and reconstruction.
|
|
1185
|
+
|
|
1186
|
+
Implements a neural network with forward and inverse projections to reduce and
|
|
1187
|
+
restore input dimensionality, supporting customizable hidden layers, dropout, and
|
|
1188
|
+
layer normalization.
|
|
1189
|
+
|
|
1190
|
+
Parameters
|
|
1191
|
+
----------
|
|
1192
|
+
`clip_embedding_dim` : int, optional
|
|
1193
|
+
Input dimensionality (default: 1024).
|
|
1194
|
+
`transformer_embedding_dim` : int, optional
|
|
1195
|
+
Output dimensionality for forward projection (default: 320).
|
|
1196
|
+
`hidden_dim` : int, optional
|
|
1197
|
+
Hidden layer dimensionality (default: 512).
|
|
1198
|
+
`num_layers` : int, optional
|
|
1199
|
+
Number of layers in the projection network (default: 2).
|
|
1200
|
+
`dropout_rate` : float, optional
|
|
1201
|
+
Dropout probability for regularization (default: 0.2).
|
|
1202
|
+
`use_layer_norm` : bool, optional
|
|
1203
|
+
Whether to apply layer normalization after hidden layers (default: True).
|
|
1204
|
+
"""
|
|
1205
|
+
def __init__(
    self,
    clip_embedding_dim: int = 1024,
    transformer_embedding_dim: int = 320,
    hidden_dim: int = 512,
    num_layers: int = 2,
    dropout_rate: float = 0.2,
    use_layer_norm: bool = True
) -> None:
    """Create the forward (reducing) and inverse (restoring) projection networks.

    Both networks share the same hidden width, depth, dropout and layer-norm settings;
    only their input and output sizes are swapped.
    """
    super().__init__()

    self.clip_embedding_dim = clip_embedding_dim
    self.transformer_embedding_dim = transformer_embedding_dim

    shared = dict(
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout_rate,
        use_layer_norm=use_layer_norm,
    )
    # clip_embedding_dim -> transformer_embedding_dim (built first, as before, so
    # seeded parameter initialisation is unchanged)
    self.forward_projection = self._build_projection_network(
        clip_embedding_dim, transformer_embedding_dim, **shared
    )
    # transformer_embedding_dim -> clip_embedding_dim
    self.inverse_projection = self._build_projection_network(
        transformer_embedding_dim, clip_embedding_dim, **shared
    )
|
|
1228
|
+
def _build_projection_network(
    self,
    input_dim: int,
    output_dim: int,
    hidden_dim: int,
    num_layers: int,
    dropout: float,
    use_layer_norm: bool
) -> nn.Sequential:
    """Assemble an MLP of ``num_layers`` linear layers.

    Each of the first ``num_layers - 1`` stages is
    ``Linear -> [LayerNorm] -> GELU -> Dropout`` and maps to ``hidden_dim``. A final
    ``Linear`` maps to ``output_dim``. With ``num_layers <= 1`` the network is a
    single ``Linear(input_dim, output_dim)``.

    Parameters
    ----------
    `input_dim` : int
        Width of the network input.
    `output_dim` : int
        Width of the network output.
    `hidden_dim` : int
        Width of every hidden stage.
    `num_layers` : int
        Total number of linear layers.
    `dropout` : float
        Dropout probability after each hidden activation.
    `use_layer_norm` : bool
        Whether to insert LayerNorm after each hidden linear layer.

    Returns
    -------
    network : nn.Sequential
        The assembled layers, in order.
    """
    modules = []
    width = input_dim

    for _ in range(num_layers - 1):
        stage = [nn.Linear(width, hidden_dim)]
        if use_layer_norm:
            stage.append(nn.LayerNorm(hidden_dim))
        stage += [nn.GELU(), nn.Dropout(dropout)]
        modules.extend(stage)
        width = hidden_dim

    modules.append(nn.Linear(width, output_dim))
    return nn.Sequential(*modules)
|
|
1278
|
+
|
|
1279
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Map ``x`` into the reduced space.

    Delegates to ``self.forward_projection``.

    Parameters
    ----------
    `x` : torch.Tensor
        Tensor to project, shape (batch_size, input_dim).

    Returns
    -------
    x_reduced : torch.Tensor
        Output of ``self.forward_projection``, shape (batch_size, output_dim).
    """
    encoder_net = self.forward_projection
    x_reduced = encoder_net(x)
    return x_reduced
|
|
1295
|
+
|
|
1296
|
+
def inverse_transform(self, x_reduced: torch.Tensor) -> torch.Tensor:
    """Map reduced features back to the original embedding space.

    Delegates to ``self.inverse_projection``.

    Parameters
    ----------
    `x_reduced` : torch.Tensor
        Tensor in the reduced space, shape (batch_size, output_dim).

    Returns
    -------
    x_reconstructed : torch.Tensor
        Output of ``self.inverse_projection``, shape (batch_size, input_dim).
    """
    decoder_net = self.inverse_projection
    x_reconstructed = decoder_net(x_reduced)
    return x_reconstructed
|
|
1313
|
+
|
|
1314
|
+
def reconstruction_loss(self, x: torch.Tensor) -> torch.Tensor:
    """Mean squared error of a forward-then-inverse round trip.

    ``x`` is projected with ``self.forward`` and mapped back with
    ``self.inverse_transform``. The result is compared to ``x`` with
    ``F.mse_loss`` (default mean reduction).

    Parameters
    ----------
    `x` : torch.Tensor
        Original input, shape (batch_size, input_dim).

    Returns
    -------
    loss : torch.Tensor
        Scalar MSE between the reconstruction and ``x``.
    """
    round_trip = self.inverse_transform(self.forward(x))
    return F.mse_loss(round_trip, x)
|
|
1333
|
+
|
|
1334
|
+
###==================================================================================================================###
|
|
1335
|
+
|
|
1336
|
+
class TrainUnClipDecoder(nn.Module):
|
|
1337
|
+
"""Trainer for the UnCLIP decoder model.
|
|
1338
|
+
|
|
1339
|
+
Orchestrates the training of the UnCLIP decoder model, integrating CLIP embeddings, forward
|
|
1340
|
+
and reverse diffusion processes, and optional dimensionality reduction. Supports mixed
|
|
1341
|
+
precision, gradient accumulation, DDP, and comprehensive evaluation metrics.
|
|
1342
|
+
|
|
1343
|
+
Parameters
|
|
1344
|
+
----------
|
|
1345
|
+
`clip_embedding_dim` : int
|
|
1346
|
+
Dimensionality of the input embeddings.
|
|
1347
|
+
`decoder_model` : nn.Module
|
|
1348
|
+
The UnCLIP decoder model (e.g., UnClipDecoder) to be trained.
|
|
1349
|
+
`clip_model` : nn.Module
|
|
1350
|
+
CLIP model for generating text and image embeddings.
|
|
1351
|
+
`train_loader` : torch.utils.data.DataLoader
|
|
1352
|
+
DataLoader for training data.
|
|
1353
|
+
`optimizer` : torch.optim.Optimizer
|
|
1354
|
+
Optimizer for training the decoder model.
|
|
1355
|
+
`objective` : Callable
|
|
1356
|
+
Loss function to compute the difference between predicted and target noise.
|
|
1357
|
+
`clip_text_projection` : nn.Module, optional
|
|
1358
|
+
Projection module for text embeddings, default None.
|
|
1359
|
+
`clip_image_projection` : nn.Module, optional
|
|
1360
|
+
Projection module for image embeddings, default None.
|
|
1361
|
+
`val_loader` : torch.utils.data.DataLoader, optional
|
|
1362
|
+
DataLoader for validation data, default None.
|
|
1363
|
+
`metrics_` : Any, optional
|
|
1364
|
+
Object providing evaluation metrics (e.g., FID, MSE, PSNR, SSIM, LPIPS), default None.
|
|
1365
|
+
`max_epochs` : int, optional
|
|
1366
|
+
Maximum number of training epochs (default: 1000).
|
|
1367
|
+
`device` : Union[str, torch.device], optional
|
|
1368
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
1369
|
+
`store_path` : str, optional
|
|
1370
|
+
Directory to save model checkpoints (default: "unclip_decoder").
|
|
1371
|
+
`patience` : int, optional
|
|
1372
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
1373
|
+
`warmup_epochs` : int, optional
|
|
1374
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
1375
|
+
`val_frequency` : int, optional
|
|
1376
|
+
Frequency (in epochs) for validation (default: 10).
|
|
1377
|
+
`use_ddp` : bool, optional
|
|
1378
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
1379
|
+
`grad_accumulation_steps` : int, optional
|
|
1380
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
1381
|
+
`log_frequency` : int, optional
|
|
1382
|
+
Frequency (in epochs) for printing progress (default: 1).
|
|
1383
|
+
`use_compilation` : bool, optional
|
|
1384
|
+
Whether to compile the model using torch.compile (default: False).
|
|
1385
|
+
`image_output_range` : Tuple[float, float], optional
|
|
1386
|
+
Range for clamping output images (default: (-1.0, 1.0)).
|
|
1387
|
+
`reduce_clip_embedding_dim` : bool, optional
|
|
1388
|
+
Whether to apply dimensionality reduction to embeddings (default: True).
|
|
1389
|
+
`transformer_embedding_dim` : int, optional
|
|
1390
|
+
Output dimensionality for reduced embeddings (default: 312).
|
|
1391
|
+
`normalize_clip_embeddings` : bool, optional
|
|
1392
|
+
Whether to normalize CLIP embeddings (default: True).
|
|
1393
|
+
`finetune_clip_projections` : bool, optional
|
|
1394
|
+
Whether to fine-tune projection layers (default: False).
|
|
1395
|
+
"""
|
|
1396
|
+
def __init__(
    self,
    clip_embedding_dim: int,
    decoder_model: nn.Module,
    clip_model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    clip_text_projection: Optional[nn.Module] = None,
    clip_image_projection: Optional[nn.Module] = None,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    metrics_: Optional[Any] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    store_path: str = "unclip_decoder",
    patience: int = 100,
    warmup_epochs: int = 100,
    val_frequency: int = 10,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    reduce_clip_embedding_dim: bool = True,
    transformer_embedding_dim: int = 312,
    normalize_clip_embeddings: bool = True,
    finetune_clip_projections: bool = False  # whether text/image projection models are fine-tuned
):
    """Configure devices, models, projections, optimisation and LR scheduling.

    Parameters are described in the ``TrainUnClipDecoder`` class docstring.

    Initialisation order matters:

    1. Resolve the requested device.
    2. Run the DDP or single-device setup. ``_setup_ddp`` replaces ``self.device``
       with ``cuda:{LOCAL_RANK}``.
    3. Move every model to the final device.
    4. Register the projection models and the flags that ``_compile_models`` and
       ``_wrap_models_for_ddp`` read.
    5. Compile and wrap the models.
    """
    super().__init__()
    # training configuration
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    self.use_compilation = use_compilation
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # flags consulted by _compile_models / _wrap_models_for_ddp must exist before they run
    self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
    self.finetune_clip_projections = finetune_clip_projections

    # setup distributed training BEFORE moving models: under DDP the final device
    # is cuda:{LOCAL_RANK}, not the device resolved above
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # core models
    self.decoder_model = decoder_model.to(self.device)
    self.clip_model = clip_model.to(self.device) if clip_model is not None else None

    # projection models (PCA equivalent in the paper); must be assigned before
    # compilation / DDP wrapping, which read these attributes
    if self.reduce_clip_embedding_dim and clip_text_projection is not None and clip_image_projection is not None:
        self.clip_text_projection = clip_text_projection.to(self.device)
        self.clip_image_projection = clip_image_projection.to(self.device)
    else:
        self.clip_text_projection = None
        self.clip_image_projection = None

    # compile and wrap models
    self._compile_models()
    self._wrap_models_for_ddp()

    # training components
    self.clip_embedding_dim = transformer_embedding_dim if self.reduce_clip_embedding_dim else clip_embedding_dim
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.train_loader = train_loader
    self.val_loader = val_loader

    # training parameters
    self.max_epochs = max_epochs
    self.patience = patience
    self.val_frequency = val_frequency
    self.log_frequency = log_frequency
    self.image_output_range = image_output_range
    self.normalize_clip_embeddings = normalize_clip_embeddings
    self.transformer_embedding_dim = transformer_embedding_dim
    self.warmup_epochs = warmup_epochs

    # checkpoint management
    self.store_path = store_path

    # learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
1490
|
+
|
|
1491
|
+
def forward(self) -> Tuple[List[float], float]:
    """Trains the UnCLIP decoder model to predict noise for denoising.

    Each epoch runs the following steps for every batch:

    - encode the batch with the frozen CLIP model;
    - optionally project the embeddings to the reduced dimension;
    - ask the decoder for a noise prediction;
    - accumulate ``objective`` gradients under autocast and a ``GradScaler``.

    Every ``grad_accumulation_steps`` batches the gradients are unscaled, clipped
    to norm 1.0 and applied. A leftover partial window is applied once more at
    the end of the epoch.

    Every ``val_frequency`` epochs, if a validation loader is set, ``validate()``
    runs and its loss becomes the monitored loss for that epoch. On other epochs
    the mean training loss is monitored. The monitored loss drives
    ``ReduceLROnPlateau``, best-model checkpointing and early stopping. Patience
    is counted in epochs.

    Returns
    -------
    train_losses : List[float]
        Mean training loss per completed epoch.
    best_val_loss : float
        Best monitored loss seen on a checkpointing epoch (only tracked on the
        master process; ``inf`` elsewhere).
    """
    # DistributedDataParallel does not forward custom attributes such as
    # ``forward_diffusion``; read them from the wrapped module instead
    decoder = getattr(self.decoder_model, "module", self.decoder_model)
    projections_enabled = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
    )

    def set_train_modes() -> None:
        # trainable parts -> train mode; frozen parts (fixed beta schedule,
        # non-finetuned projections, CLIP) -> eval mode
        self.decoder_model.train()
        if not decoder.forward_diffusion.variance_scheduler.trainable_beta:
            decoder.forward_diffusion.variance_scheduler.eval()
        if projections_enabled:
            if self.finetune_clip_projections:
                self.clip_text_projection.train()
                self.clip_image_projection.train()
            else:
                self.clip_text_projection.eval()
                self.clip_image_projection.eval()
        if self.clip_model is not None:
            self.clip_model.eval()

    scaler = torch.GradScaler()

    def optimizer_step() -> None:
        # unscale so clipping sees the true gradient norms, then apply the update
        scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), max_norm=1.0)  # covers all submodules
        if projections_enabled and self.finetune_clip_projections:
            torch.nn.utils.clip_grad_norm_(self.clip_text_projection.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.clip_image_projection.parameters(), max_norm=1.0)
        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad()
        # LambdaLR sets lr = base_lr * factor on every step, which would undo any
        # ReduceLROnPlateau reduction; stop stepping it once warmup has completed
        warmup = self.warmup_lr_scheduler
        lr_lambdas = getattr(warmup, "lr_lambdas", None)
        if lr_lambdas is None or lr_lambdas[0](warmup.last_epoch) < 1.0:
            warmup.step()
        torch.cuda.empty_cache()  # clear memory after optimizer step

    train_losses: List[float] = []
    best_val_loss = float("inf")
    wait = 0
    autocast_device = 'cuda' if self.device.type == 'cuda' else 'cpu'

    # main training loop
    for epoch in range(self.max_epochs):
        # re-applied every epoch: validate() may leave modules in eval mode
        set_train_modes()

        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch: List[float] = []
        self.optimizer.zero_grad()

        # training step loop with gradient accumulation
        for step, (images, texts) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
            images = images.to(self.device, non_blocking=True)

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                # encode text and image with CLIP
                text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
                # reduce dimensionality (PCA equivalent)
                text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                    text_embeddings, image_embeddings
                )
                # use decoder model to predict noise
                p_classifier_free = torch.rand(1).item()
                p_text_drop = torch.rand(1).item()
                predicted_noise, noise = self.decoder_model(
                    image_embeddings,
                    text_embeddings,
                    images,
                    texts,
                    p_classifier_free,
                    p_text_drop
                )
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            scaler.scale(loss).backward()

            if (step + 1) % self.grad_accumulation_steps == 0:
                optimizer_step()

            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # apply gradients from an incomplete accumulation window so they are
        # neither dropped nor leaked into the next epoch's first update
        if len(train_losses_epoch) % self.grad_accumulation_steps != 0:
            optimizer_step()

        mean_train_loss = self._compute_mean_loss(train_losses_epoch)
        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        current_loss = mean_train_loss

        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()
            # monitor the validation loss whenever it is available
            current_loss = val_loss

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

        self.scheduler.step(current_loss)

        stop_training = False
        if self.master_process:
            if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                    stop_training = True

        if self.use_ddp:
            # only rank 0 tracks patience; broadcast its decision so every rank
            # leaves the loop together instead of blocking in a collective
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
1625
|
+
|
|
1626
|
+
def _setup_ddp(self) -> None:
    """Prepare this process for DistributedDataParallel training.

    The method does the following:

    - Requires the ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` environment
      variables.
    - Requires CUDA.
    - Initialises an NCCL process group if none exists yet.
    - Records the rank information.
    - Pins ``self.device`` to ``cuda:{LOCAL_RANK}``.
    - Marks rank 0 as the master process.

    Raises
    ------
    ValueError
        If one of the required environment variables is missing. The first
        missing one is named.
    RuntimeError
        If CUDA is not available.
    """
    env = os.environ
    missing = [name for name in ("RANK", "LOCAL_RANK", "WORLD_SIZE") if name not in env]
    if missing:
        raise ValueError(f"DDP enabled but {missing[0]} environment variable not set")

    if not torch.cuda.is_available():
        raise RuntimeError("DDP requires CUDA but CUDA is not available")

    if not torch.distributed.is_initialized():
        init_process_group(backend="nccl")

    self.ddp_rank = int(env["RANK"])
    self.ddp_local_rank = int(env["LOCAL_RANK"])
    self.ddp_world_size = int(env["WORLD_SIZE"])

    # each process drives the GPU matching its local rank
    self.device = torch.device(f"cuda:{self.ddp_local_rank}")
    torch.cuda.set_device(self.device)

    self.master_process = self.ddp_rank == 0
    if self.master_process:
        print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
1654
|
+
|
|
1655
|
+
def _setup_single_gpu(self) -> None:
    """Record rank information for non-distributed (single device) training.

    The process is treated as rank 0 of a world of size 1 and is the master process.
    ``self.device`` is left unchanged.
    """
    self.ddp_rank = self.ddp_local_rank = 0
    self.ddp_world_size = 1
    self.master_process = True
|
|
1665
|
+
|
|
1666
|
+
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear-warmup ``LambdaLR`` for ``optimizer``.

    After ``k`` scheduler steps the learning rate is
    ``base_lr * min(1, k / warmup_epochs)``. It therefore starts at 0 and reaches
    the base value after ``warmup_epochs`` steps. A non-positive ``warmup_epochs``
    disables warmup (constant factor 1.0).

    Parameters
    ----------
    `optimizer` : torch.optim.Optimizer
        Optimizer whose learning rate is scheduled.
    `warmup_epochs` : int
        Number of scheduler steps the warmup lasts.

    Returns
    -------
    lr_scheduler : torch.optim.lr_scheduler.LambdaLR
        The warmup scheduler.
    """
    if warmup_epochs > 0:
        def lr_lambda(epoch):
            return min(1.0, epoch / warmup_epochs)
    else:
        def lr_lambda(epoch):
            return 1.0
    return LambdaLR(optimizer, lr_lambda)
|
|
1689
|
+
|
|
1690
|
+
def _wrap_models_for_ddp(self) -> None:
    """Wrap trainable models in DistributedDataParallel when DDP is enabled.

    When ``self.use_ddp`` is true, the decoder is moved to ``self.ddp_local_rank``
    and always wrapped, with ``find_unused_parameters=True``. The text and image
    projections are wrapped only when dimensionality reduction is enabled, both
    projections exist, and they are being fine-tuned. When ``self.use_ddp`` is
    false, nothing happens.
    """
    if not self.use_ddp:
        return

    rank = self.ddp_local_rank
    self.decoder_model = DDP(
        self.decoder_model.to(rank),
        device_ids=[rank],
        find_unused_parameters=True
    )

    # frozen projections receive no gradients, so they stay unwrapped
    projections_trainable = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
        and self.finetune_clip_projections
    )
    if projections_trainable:
        names = ("clip_text_projection", "clip_image_projection")
        for name in names:
            setattr(self, name, getattr(self, name).to(rank))
        for name in names:
            setattr(self, name, DDP(getattr(self, name), device_ids=[rank]))
|
|
1708
|
+
|
|
1709
|
+
def _compile_models(self) -> None:
    """Optionally ``torch.compile`` the trainable models (best effort).

    When ``self.use_compilation`` is true, the decoder is moved to ``self.device``
    and compiled with ``mode="reduce-overhead"``. The projections get the same
    treatment, but only when dimensionality reduction is enabled, both
    projections exist, and they are being fine-tuned.

    Any exception is reported on the master process, and training continues
    with whatever was compiled before the failure.
    """
    if not self.use_compilation:
        return
    try:
        self.decoder_model = torch.compile(self.decoder_model.to(self.device), mode="reduce-overhead")

        projections_trainable = (
            self.reduce_clip_embedding_dim
            and self.clip_text_projection is not None
            and self.clip_image_projection is not None
            and self.finetune_clip_projections
        )
        if projections_trainable:
            names = ("clip_text_projection", "clip_image_projection")
            for name in names:
                setattr(self, name, getattr(self, name).to(self.device))
            for name in names:
                setattr(self, name, torch.compile(getattr(self, name), mode="reduce-overhead"))

        if self.master_process:
            print("Models compiled successfully")
    except Exception as e:
        if self.master_process:
            print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
1730
|
+
|
|
1731
|
+
def _get_clip_embeddings(
    self,
    images: torch.Tensor,
    texts: Union[List, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode texts and images with the (frozen) CLIP model, without gradients.

    ``self.clip_model`` is called twice, first with ``data_type="text"`` and then
    with ``data_type="img"``. Both calls pass ``normalize=self.normalize_clip_embeddings``.

    Parameters
    ----------
    `images` : torch.Tensor
        Input images, shape (batch_size, channels, height, width).
    `texts` : Union[List, torch.Tensor]
        Text prompts for conditional generation.

    Returns
    -------
    text_embeddings : torch.Tensor
        CLIP text embeddings, shape (batch_size, embedding_dim).
    image_embeddings : torch.Tensor
        CLIP image embeddings, shape (batch_size, embedding_dim).
    """
    encode = self.clip_model
    normalize = self.normalize_clip_embeddings
    with torch.no_grad():
        # z_t <- CLIP_text(y)
        text_embeddings = encode(data=texts, data_type="text", normalize=normalize)
        # z_i <- CLIP_image(x)
        image_embeddings = encode(data=images, data_type="img", normalize=normalize)
    return text_embeddings, image_embeddings
|
|
1760
|
+
|
|
1761
|
+
def _apply_dimensionality_reduction(
    self,
    text_embeddings: torch.Tensor,
    image_embeddings: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Project CLIP embeddings to the reduced space (PCA stand-in) when enabled.

    Projection happens only if ``self.reduce_clip_embedding_dim`` is set and both
    projection models exist. It runs under ``torch.no_grad()`` unless
    ``self.finetune_clip_projections`` is true. Both outputs are moved to
    ``self.device`` in every case.

    Parameters
    ----------
    `text_embeddings` : torch.Tensor
        CLIP text embeddings, shape (batch_size, embedding_dim).
    `image_embeddings` : torch.Tensor
        CLIP image embeddings, shape (batch_size, embedding_dim).

    Returns
    -------
    text_embeddings : torch.Tensor
        Projected text embeddings if reduction applies, otherwise the input.
    image_embeddings : torch.Tensor
        Projected image embeddings if reduction applies, otherwise the input.
    """
    device = self.device
    projections_ready = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
    )
    if projections_ready:
        def project(text, image):
            projected_text = self.clip_text_projection(text.to(device))
            projected_image = self.clip_image_projection(image.to(device))
            return projected_text, projected_image

        if self.finetune_clip_projections:
            text_embeddings, image_embeddings = project(text_embeddings, image_embeddings)
        else:
            # frozen projections: no autograd graph needed
            with torch.no_grad():
                text_embeddings, image_embeddings = project(text_embeddings, image_embeddings)
    return text_embeddings.to(device), image_embeddings.to(device)
|
|
1794
|
+
|
|
1795
|
+
def _compute_mean_loss(self, losses: List[float]) -> float:
    """Average per-step losses, reducing across processes under DDP.

    Without DDP this is the plain mean of ``losses``, or 0.0 when ``losses`` is
    empty.

    Under DDP every rank contributes its local loss sum and step count to a
    single ``all_reduce``, and the result is the global mean over all recorded
    steps. This equals the previous mean-of-rank-means when all ranks ran the
    same number of steps.

    Every rank takes part in the collective even when its list is empty.
    Returning early on one rank would leave the others blocked in
    ``all_reduce``.

    Parameters
    ----------
    `losses` : List[float]
        Loss values recorded during the current epoch on this process.

    Returns
    -------
    mean_loss : float
        Mean loss, synchronized across processes when using DDP; 0.0 if no
        losses were recorded.
    """
    local_sum = float(sum(losses))
    local_count = len(losses)

    if not self.use_ddp:
        return local_sum / local_count if local_count else 0.0

    # reduce [sum, count] together so ranks with different step counts are
    # weighted correctly and empty ranks still join the collective
    totals = torch.tensor([local_sum, float(local_count)], device=self.device, dtype=torch.float64)
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    global_sum, global_count = totals.tolist()
    return global_sum / global_count if global_count > 0 else 0.0
|
|
1821
|
+
|
|
1822
|
+
def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
    """Saves model checkpoint (master process only).

    The checkpoint contains:

    - the noise predictor and, if present, the GLIDE text encoder;
    - the variance scheduler;
    - the CLIP time projection, stored under both ``clip_time_proj_state_dict``
      and ``clip_time_projection_state_dict``;
    - the decoder projection;
    - the CLIP projections when dimensionality reduction is active;
    - the optimizer and both LR schedulers;
    - a few configuration values.

    DDP wrappers are unwrapped before any attribute access, so saving works
    under DDP whether or not the projections were wrapped.

    The file is written to ``store_path`` as ``unclip_decoder_best{suffix}.pth``
    when ``is_best``, otherwise as ``unclip_decoder_epoch_{epoch}{suffix}.pth``.

    Parameters
    ----------
    `epoch` : int
        Current epoch number.
    `loss` : float
        Current loss value.
    `is_best` : bool, optional
        Whether to save as the best model checkpoint (default: False).
    `suffix` : str, optional
        Suffix to add to checkpoint filename, default "".
    """
    if not self.master_process:
        return

    def _unwrap(module):
        # DistributedDataParallel hides submodules behind ``.module``; projections
        # are only DDP-wrapped when fine-tuned, so check the type instead of use_ddp
        return module.module if isinstance(module, DDP) else module

    decoder = _unwrap(self.decoder_model)

    checkpoint = {
        'epoch': epoch,
        'loss': loss,
        # core models (submodules of decoder_model)
        'noise_predictor_state_dict': decoder.noise_predictor.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        # training configuration
        'embedding_dim': self.clip_embedding_dim,
        'output_dim': self.transformer_embedding_dim,
        'reduce_dim': self.reduce_clip_embedding_dim,
        'normalize': self.normalize_clip_embeddings
    }

    # conditional model (submodule of decoder_model)
    if decoder.glide_text_encoder is not None:
        checkpoint['conditional_model_state_dict'] = decoder.glide_text_encoder.state_dict()

    # variance scheduler (always saved)
    checkpoint['variance_scheduler_state_dict'] = decoder.forward_diffusion.variance_scheduler.state_dict()

    # CLIP time projection and decoder projection layers
    checkpoint['clip_time_proj_state_dict'] = decoder.clip_time_projection.state_dict()
    checkpoint['decoder_projection_state_dict'] = decoder.clip_decoder_projection.state_dict()
    # same weights under a second key name, kept for loaders that read either key
    checkpoint['clip_time_projection_state_dict'] = decoder.clip_time_projection.state_dict()

    # projection models (PCA equivalent)
    if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
        checkpoint['text_projection_state_dict'] = _unwrap(self.clip_text_projection).state_dict()
        checkpoint['image_projection_state_dict'] = _unwrap(self.clip_image_projection).state_dict()

    # schedulers state
    checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
    checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()

    if is_best:
        filename = f"unclip_decoder_best{suffix}.pth"
    else:
        filename = f"unclip_decoder_epoch_{epoch}{suffix}.pth"

    filepath = os.path.join(self.store_path, filename)
    os.makedirs(self.store_path, exist_ok=True)
    torch.save(checkpoint, filepath)

    if is_best:
        print(f"Best model saved: {filepath}")
|
|
1909
|
+
|
|
1910
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads model checkpoint.

    Restores the state of the decoder model, its submodules, optimizer, and schedulers
    from a saved checkpoint, handling DDP compatibility. Every component is optional:
    it is restored only when its key is present in the checkpoint. A component that
    fails to load produces a warning instead of aborting the whole restore.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if the checkpoint has none).
    loss : float
        The loss at the checkpoint (``inf`` if the checkpoint has none).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from e

    ddp_prefix = 'module.'

    def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
        """Helper function to load state dict with DDP compatibility.

        Adds or strips the ``module.`` prefix that DistributedDataParallel puts on
        parameter names so the checkpoint keys match the keys ``model`` expects.
        Any failure is reported as a warning.
        """
        try:
            # Decide from the target module's own keys rather than from self.use_ddp.
            # Submodules reached through decoder_model are not necessarily DDP-wrapped
            # even when DDP is enabled.
            model_is_wrapped = any(key.startswith(ddp_prefix) for key in model.state_dict())
            ckpt_is_wrapped = any(key.startswith(ddp_prefix) for key in state_dict)
            if model_is_wrapped and not ckpt_is_wrapped:
                state_dict = {f'{ddp_prefix}{k}': v for k, v in state_dict.items()}
            elif ckpt_is_wrapped and not model_is_wrapped:
                # Strip only the leading prefix. str.replace would also corrupt inner
                # names such as 'submodule.weight'.
                state_dict = {
                    (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
                    for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            if self.master_process:
                print(f"✓ Loaded {model_name}")
        except Exception as e:
            warnings.warn(f"Failed to load {model_name}: {e}")

    # load core noise predictor model (submodule of decoder_model)
    if 'noise_predictor_state_dict' in checkpoint:
        _load_model_state_dict(self.decoder_model.noise_predictor, checkpoint['noise_predictor_state_dict'],
                               'noise_predictor')

    # load conditional model (GLIDE text encoder), only if the decoder has one
    if self.decoder_model.glide_text_encoder is not None and 'conditional_model_state_dict' in checkpoint:
        _load_model_state_dict(self.decoder_model.glide_text_encoder, checkpoint['conditional_model_state_dict'],
                               'glide_text_encoder')

    # The outer try blocks below also catch errors raised while *resolving* the submodule
    # attribute. That lookup happens before the helper's own try is entered.

    # load variance scheduler (submodule of decoder_model)
    if 'variance_scheduler_state_dict' in checkpoint:
        try:
            _load_model_state_dict(self.decoder_model.forward_diffusion.variance_scheduler,
                                   checkpoint['variance_scheduler_state_dict'], 'variance_scheduler')
        except Exception as e:
            warnings.warn(f"Failed to load variance scheduler: {e}")

    # load CLIP time projection layer (submodule of decoder_model)
    if 'clip_time_proj_state_dict' in checkpoint:
        try:
            _load_model_state_dict(self.decoder_model.clip_time_projection,
                                   checkpoint['clip_time_proj_state_dict'], 'clip_time_projection')
        except Exception as e:
            warnings.warn(f"Failed to load CLIP time projection: {e}")

    # load decoder projection layer (submodule of decoder_model)
    if 'decoder_projection_state_dict' in checkpoint:
        try:
            _load_model_state_dict(self.decoder_model.clip_decoder_projection,
                                   checkpoint['decoder_projection_state_dict'], 'clip_decoder_projection')
        except Exception as e:
            warnings.warn(f"Failed to load decoder projection: {e}")

    # 'clip_time_projection_state_dict' would restore the same layer as
    # 'clip_time_proj_state_dict' above, so it is deliberately skipped (with a warning)
    # to avoid overwriting.
    if 'clip_time_projection_state_dict' in checkpoint and self.master_process:
        warnings.warn(
            "Found duplicate 'clip_time_projection_state_dict' in checkpoint - skipping to avoid conflict")

    # load projection models (PCA equivalent)
    if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
        if 'text_projection_state_dict' in checkpoint:
            _load_model_state_dict(self.clip_text_projection, checkpoint['text_projection_state_dict'],
                                   'text_projection')
        if 'image_projection_state_dict' in checkpoint:
            _load_model_state_dict(self.clip_image_projection, checkpoint['image_projection_state_dict'],
                                   'image_projection')

    # load optimizer
    if 'optimizer_state_dict' in checkpoint:
        try:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if self.master_process:
                print("✓ Loaded optimizer")
        except Exception as e:
            warnings.warn(f"Failed to load optimizer state: {e}")

    # load schedulers
    if 'scheduler_state_dict' in checkpoint:
        try:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded main scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load scheduler state: {e}")

    if 'warmup_scheduler_state_dict' in checkpoint:
        try:
            self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded warmup scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load warmup scheduler state: {e}")

    # verify configuration compatibility (mismatches only warn)
    if 'embedding_dim' in checkpoint:
        if checkpoint['embedding_dim'] != self.clip_embedding_dim:
            warnings.warn(
                f"Embedding dimension mismatch: checkpoint={checkpoint['embedding_dim']}, current={self.clip_embedding_dim}")

    if 'reduce_dim' in checkpoint:
        if checkpoint['reduce_dim'] != self.reduce_clip_embedding_dim:
            warnings.warn(
                f"Reduce dimension setting mismatch: checkpoint={checkpoint['reduce_dim']}, current={self.reduce_clip_embedding_dim}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

    return epoch, loss
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
def validate(self) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Validates the UnCLIP decoder model.

    Computes validation loss and optional metrics (FID, MSE, PSNR, SSIM, LPIPS) by
    encoding images and texts, applying forward diffusion, predicting noise, and
    reconstructing images through reverse diffusion.

    The reverse-diffusion conditioning (guidance, text dropout, projections, context)
    is built once per batch and reused at every timestep. Models are switched to eval
    mode for the pass and returned to their training modes afterwards.

    Returns
    -------
    val_loss : float
        Mean validation loss.
    fid_avg : float
        Average FID score, or ``inf`` if it was not computed.
    mse_avg : float or None
        Average MSE score, if computed.
    psnr_avg : float or None
        Average PSNR score, if computed.
    ssim_avg : float or None
        Average SSIM score, if computed.
    lpips_avg : float or None
        Average LPIPS score, if computed.
    """
    has_projections = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
    )

    # set models to eval mode for evaluation
    self.decoder_model.eval()
    if has_projections:
        self.clip_text_projection.eval()
        self.clip_image_projection.eval()
    if self.clip_model is not None:
        self.clip_model.eval()

    val_losses = []
    fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

    with torch.no_grad():
        for images, texts in self.val_loader:
            images = images.to(self.device, non_blocking=True)
            images_orig = images.clone()
            text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
            text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                text_embeddings, image_embeddings
            )
            p_classifier_free = torch.rand(1).item()
            p_text_drop = torch.rand(1).item()
            predicted_noise, noise = self.decoder_model(
                image_embeddings,
                text_embeddings,
                images,
                texts,
                p_classifier_free,
                p_text_drop
            )
            loss = self.objective(predicted_noise, noise)
            val_losses.append(loss.item())

            if self.metrics_ is not None and self.decoder_model.reverse_diffusion is not None:
                # The conditioning is fixed for the whole sampling trajectory, so build it
                # once. Re-applying guidance/dropout at every step would compound it on
                # already-modified embeddings.
                guided_image_embeddings = self.decoder_model._apply_classifier_free_guidance(
                    image_embeddings, p_classifier_free)
                guided_text_embeddings = self.decoder_model._apply_text_dropout(text_embeddings, p_text_drop)
                c = self.decoder_model.clip_decoder_projection(guided_image_embeddings)
                y_encoded = self.decoder_model._encode_text_with_glide(
                    texts if guided_text_embeddings is not None else None)
                context = self.decoder_model._concatenate_embeddings(y_encoded, c)
                clip_image_embedding = self.decoder_model.clip_time_projection(guided_image_embeddings)

                num_steps = self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps
                xt = torch.randn_like(images)
                for t in reversed(range(num_steps)):
                    time_steps = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
                    prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device, dtype=torch.long)
                    predicted_noise = self.decoder_model.noise_predictor(xt, time_steps, context, clip_image_embedding)
                    xt, _ = self.decoder_model.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)

                low, high = self.image_output_range
                x_hat = torch.clamp(xt, min=low, max=high)
                # x_orig must exist on both branches; previously it was only defined when
                # rescaling, which raised NameError otherwise
                x_orig = images_orig
                if self.normalize_clip_embeddings:
                    x_hat = (x_hat - low) / (high - low)
                    x_orig = (images_orig - low) / (high - low)

                metrics_result = self.metrics_.forward(x_orig, x_hat)
                fid = metrics_result[0] if getattr(self.metrics_, 'fid', False) else float('inf')
                mse = metrics_result[1] if getattr(self.metrics_, 'metrics', False) else None
                psnr = metrics_result[2] if getattr(self.metrics_, 'metrics', False) else None
                ssim = metrics_result[3] if getattr(self.metrics_, 'metrics', False) else None
                lpips_score = metrics_result[4] if getattr(self.metrics_, 'lpips', False) else None

                if fid != float('inf'):
                    fid_scores.append(fid)
                if mse is not None:
                    mse_scores.append(mse)
                if psnr is not None:
                    psnr_scores.append(psnr)
                if ssim is not None:
                    ssim_scores.append(ssim)
                if lpips_score is not None:
                    lpips_scores.append(lpips_score)

    # compute averages
    val_loss = torch.tensor(val_losses).mean().item()
    fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
    mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
    psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
    ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
    lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

    # Synchronize metrics across GPUs in DDP mode. inf stands in for "not computed"
    # during the reduction and is mapped back to None for MSE, PSNR, SSIM and LPIPS.
    if self.use_ddp:
        metrics = [val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg]
        metrics_tensors = [torch.tensor(m, device=self.device) if m is not None else torch.tensor(float('inf'), device=self.device) for m in metrics]
        for tensor in metrics_tensors:
            dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
        val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg = [t.item() if t.item() != float('inf') else (None if i > 1 else float('inf')) for i, t in enumerate(metrics_tensors)]

    # return to training mode
    self.decoder_model.train()
    variance_scheduler = self.decoder_model.forward_diffusion.variance_scheduler
    if not variance_scheduler.trainable_beta:
        variance_scheduler.eval()
        # reverse_diffusion is optional (see the metrics guard above)
        if self.decoder_model.reverse_diffusion is not None:
            self.decoder_model.reverse_diffusion.variance_scheduler.eval()
    if has_projections:
        if self.finetune_clip_projections:
            self.clip_text_projection.train()
            self.clip_image_projection.train()
        else:
            self.clip_text_projection.eval()
            self.clip_image_projection.eval()
    if self.clip_model is not None:
        self.clip_model.eval()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
2170
|
+
|
|
2171
|
+
###==================================================================================================================###
|
|
2172
|
+
|
|
2173
|
+
class TrainUnCLIPPrior(nn.Module):
|
|
2174
|
+
"""Trainer for the UnCLIPTransformerPrior model.
|
|
2175
|
+
|
|
2176
|
+
Handles the training of the UnCLIP prior model to predict clean image embeddings from
|
|
2177
|
+
noisy image embeddings and text embeddings, with support for dimension reduction,
|
|
2178
|
+
mixed precision training, and distributed training.
|
|
2179
|
+
|
|
2180
|
+
Parameters
|
|
2181
|
+
----------
|
|
2182
|
+
`prior_model` : nn.Module
|
|
2183
|
+
The UnCLIP prior model to be trained (e.g., UnCLIPTransformerPrior).
|
|
2184
|
+
`clip_model` : nn.Module
|
|
2185
|
+
CLIP model for encoding text and images.
|
|
2186
|
+
`train_loader` : torch.utils.data.DataLoader
|
|
2187
|
+
DataLoader for training data.
|
|
2188
|
+
`optimizer` : torch.optim.Optimizer
|
|
2189
|
+
Optimizer for training the prior model.
|
|
2190
|
+
`objective` : Callable
|
|
2191
|
+
Loss function to compute the difference between predicted and target embeddings.
|
|
2192
|
+
`val_loader` : torch.utils.data.DataLoader, optional
|
|
2193
|
+
DataLoader for validation data, default None.
|
|
2194
|
+
`max_epochs` : int, optional
|
|
2195
|
+
Maximum number of training epochs (default: 1000).
|
|
2196
|
+
`device` : Union[str, torch.device], optional
|
|
2197
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
2198
|
+
`store_path` : str, optional
|
|
2199
|
+
Directory path to save model checkpoints, default None.
|
|
2200
|
+
`patience` : int, optional
|
|
2201
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
2202
|
+
`warmup_epochs` : int, optional
|
|
2203
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
2204
|
+
`val_frequency` : int, optional
|
|
2205
|
+
Frequency (in epochs) for validation (default: 10).
|
|
2206
|
+
`use_ddp` : bool, optional
|
|
2207
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
2208
|
+
`grad_accumulation_steps` : int, optional
|
|
2209
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
2210
|
+
`log_frequency` : int, optional
|
|
2211
|
+
Frequency (in epochs) for printing training progress (default: 1).
|
|
2212
|
+
`use_compilation` : bool, optional
|
|
2213
|
+
Whether to compile models for optimization (default: False).
|
|
2214
|
+
`embedding_output_range` : Tuple[float, float], optional
|
|
2215
|
+
Range for clamping output embeddings (default: (-1.0, 1.0)).
|
|
2216
|
+
`reduce_clip_embedding_dim` : bool, optional
|
|
2217
|
+
Whether to apply dimension reduction to embeddings (default: True).
|
|
2218
|
+
`transformer_embedding_dim` : int, optional
|
|
2219
|
+
Target dimensionality for reduced embeddings (default: 319).
|
|
2220
|
+
`normalize_clip_embeddings` : bool, optional
|
|
2221
|
+
Whether to normalize CLIP embeddings (default: True).
|
|
2222
|
+
"""
|
|
2223
|
+
|
|
2224
|
+
def __init__(
    self,
    prior_model: nn.Module,
    clip_model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    store_path: Optional[str] = None,
    patience: int = 100,
    warmup_epochs: int = 100,
    val_frequency: int = 10,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False,
    embedding_output_range: Tuple[float, float] = (-1.0, 1.0),
    reduce_clip_embedding_dim: bool = True,
    transformer_embedding_dim: int = 319,
    normalize_clip_embeddings: bool = True
) -> None:
    """Stores the training configuration, places the models and builds the LR schedulers.

    See the class docstring for the meaning of each argument. When ``device`` is None,
    CUDA is chosen if available, otherwise CPU. In DDP mode ``_setup_ddp`` replaces it
    with the per-process CUDA device.
    """
    super().__init__()

    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps

    # resolve the requested device into a torch.device
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device) if isinstance(device, str) else device

    # rank bookkeeping; the DDP variant may also overwrite self.device
    (self._setup_ddp if self.use_ddp else self._setup_single_gpu)()

    # models are moved only once the final device is known
    self.prior_model = prior_model.to(self.device)
    self.clip_model = clip_model.to(self.device)

    # optimisation components and data
    self.optimizer = optimizer
    self.objective = objective
    self.train_loader, self.val_loader = train_loader, val_loader

    # loop control and embedding handling
    self.max_epochs, self.patience = max_epochs, patience
    self.val_frequency, self.log_frequency = val_frequency, log_frequency
    self.use_compilation = use_compilation
    self.embedding_output_range = embedding_output_range
    self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
    self.normalize_clip_embeddings = normalize_clip_embeddings
    self.transformer_embedding_dim = transformer_embedding_dim

    # directory that receives checkpoints
    self.store_path = store_path

    # plateau-based halving of the LR, plus a linear warmup schedule
    self.scheduler = ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
2296
|
+
|
|
2297
|
+
|
|
2298
|
+
def _setup_ddp(self) -> None:
|
|
2299
|
+
"""Sets up Distributed Data Parallel training configuration.
|
|
2300
|
+
|
|
2301
|
+
Initializes the process group, sets up rank information, and configures the CUDA
|
|
2302
|
+
device for the current process.
|
|
2303
|
+
|
|
2304
|
+
Raises
|
|
2305
|
+
------
|
|
2306
|
+
ValueError
|
|
2307
|
+
If required DDP environment variables (RANK, LOCAL_RANK, WORLD_SIZE) are not set.
|
|
2308
|
+
RuntimeError
|
|
2309
|
+
If CUDA is not available when DDP is enabled.
|
|
2310
|
+
"""
|
|
2311
|
+
|
|
2312
|
+
required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
|
|
2313
|
+
for var in required_env_vars:
|
|
2314
|
+
if var not in os.environ:
|
|
2315
|
+
raise ValueError(f"DDP enabled but {var} environment variable not set")
|
|
2316
|
+
|
|
2317
|
+
# ensure CUDA is available for DDP
|
|
2318
|
+
if not torch.cuda.is_available():
|
|
2319
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
2320
|
+
|
|
2321
|
+
# initialize process group only if not already initialized
|
|
2322
|
+
if not torch.distributed.is_initialized():
|
|
2323
|
+
init_process_group(backend="nccl")
|
|
2324
|
+
|
|
2325
|
+
# get rank information
|
|
2326
|
+
self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
|
|
2327
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
|
|
2328
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
|
|
2329
|
+
|
|
2330
|
+
# set device and make it current
|
|
2331
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
2332
|
+
torch.cuda.set_device(self.device)
|
|
2333
|
+
|
|
2334
|
+
# master process handles logging, checkpointing, etc.
|
|
2335
|
+
self.master_process = self.ddp_rank == 0
|
|
2336
|
+
|
|
2337
|
+
if self.master_process:
|
|
2338
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
2339
|
+
|
|
2340
|
+
|
|
2341
|
+
def _setup_single_gpu(self) -> None:
|
|
2342
|
+
"""Sets up single GPU or CPU training configuration.
|
|
2343
|
+
|
|
2344
|
+
Configures the training setup for single-device operation, setting rank and process
|
|
2345
|
+
information for non-DDP training.
|
|
2346
|
+
"""
|
|
2347
|
+
self.ddp_rank = 0
|
|
2348
|
+
self.ddp_local_rank = 0
|
|
2349
|
+
self.ddp_world_size = 1
|
|
2350
|
+
self.master_process = True
|
|
2351
|
+
|
|
2352
|
+
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Creates a learning rate scheduler for warmup.

    The returned LambdaLR multiplies the optimizer's initial learning rate by
    ``min(1, step / warmup_epochs)``, ramping linearly from 0 up to the full value.
    A non-positive ``warmup_epochs`` disables the ramp (factor is always 1).

    Parameters
    ----------
    `optimizer` : torch.optim.Optimizer
        Optimizer to apply the scheduler to.
    `warmup_epochs` : int
        Length of the warmup phase, counted in scheduler steps.

    Returns
    -------
    lr_scheduler : torch.optim.lr_scheduler.LambdaLR
        Learning rate scheduler for warmup.
    """
    def _warmup_factor(step):
        if warmup_epochs <= 0:
            return 1.0
        return min(1.0, step / warmup_epochs)

    return LambdaLR(optimizer, _warmup_factor)
|
|
2374
|
+
|
|
2375
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
2376
|
+
"""Wraps the prior model with DistributedDataParallel for multi-GPU training.
|
|
2377
|
+
|
|
2378
|
+
Configures the prior model for DDP, setting device IDs and handling unused parameters.
|
|
2379
|
+
"""
|
|
2380
|
+
if self.use_ddp:
|
|
2381
|
+
# wrap prior with DDP
|
|
2382
|
+
self.prior_model = DDP(
|
|
2383
|
+
self.prior_model,
|
|
2384
|
+
device_ids=[self.ddp_local_rank],
|
|
2385
|
+
find_unused_parameters=True
|
|
2386
|
+
)
|
|
2387
|
+
|
|
2388
|
+
def _compile_models(self) -> None:
|
|
2389
|
+
"""Compiles models for optimization if supported.
|
|
2390
|
+
|
|
2391
|
+
Attempts to compile the prior model using torch.compile for performance optimization,
|
|
2392
|
+
with fallback to uncompiled models if compilation fails.
|
|
2393
|
+
"""
|
|
2394
|
+
if self.use_compilation:
|
|
2395
|
+
try:
|
|
2396
|
+
self.prior_model = torch.compile(self.prior_model)
|
|
2397
|
+
|
|
2398
|
+
if self.master_process:
|
|
2399
|
+
print("Models compiled successfully")
|
|
2400
|
+
except Exception as e:
|
|
2401
|
+
if self.master_process:
|
|
2402
|
+
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
2403
|
+
|
|
2404
|
+
def forward(self) -> Tuple[List[float], float]:
    """Trains the UnCLIP prior model.

    Executes the training loop, optimizing the prior model to predict clean image embeddings
    from noisy embeddings and text conditions, with support for validation, early stopping,
    and checkpointing.

    Gradients are accumulated over ``grad_accumulation_steps`` micro-batches. A trailing
    partial window at the end of an epoch is applied instead of leaking into the next
    epoch. In DDP mode the early-stopping decision taken on the master process is
    broadcast, so every rank leaves the loop together.

    Returns
    -------
    train_losses : List[float]
        List of mean training losses per epoch.
    best_val_loss : float
        Best validation or training loss achieved (tracked on the master process).
    """
    # set models to training mode
    self.prior_model.train()

    # compile and wrap models
    self._compile_models()
    self._wrap_models_for_ddp()

    # self.device is a torch.device, so compare its .type. A torch.device never equals
    # the plain string 'cuda', and under DDP it is 'cuda:<rank>'.
    on_cuda = self.device.type == 'cuda'
    autocast_device = 'cuda' if on_cuda else 'cpu'

    # initialize training components; loss scaling is only needed for CUDA autocast
    scaler = torch.GradScaler(enabled=on_cuda)
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
            x = x.to(self.device, non_blocking=True)

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                loss = self._compute_training_loss(x, y)
                loss = loss / self.grad_accumulation_steps

            # backward pass
            scaler.scale(loss).backward()

            # optimizer step with gradient accumulation
            if (step + 1) % self.grad_accumulation_steps == 0:
                self._optimizer_step(scaler)
                # update learning rate (warmup scheduler)
                self.warmup_lr_scheduler.step()

            # record loss (undo the accumulation scaling)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # apply gradients from an incomplete final accumulation window so they are not
        # silently merged into the first update of the next epoch
        if len(train_losses_epoch) % self.grad_accumulation_steps != 0:
            self._optimizer_step(scaler)
            self.warmup_lr_scheduler.step()

        # compute and sync training loss
        mean_train_loss = self._compute_mean_loss(train_losses_epoch)
        train_losses.append(mean_train_loss)

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}", end="")

        # validation (run on every rank, since validate may use collectives)
        current_loss = mean_train_loss
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss = self.validate()
            current_loss = val_loss

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}")
        elif self.master_process:
            print()

        # learning rate scheduling
        self.scheduler.step(current_loss)

        # save checkpoint and decide on early stopping (master process only)
        stop_training = False
        if self.master_process:
            if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                    stop_training = True

        # every rank must stop together; otherwise the remaining ranks block forever in
        # their next collective call
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())
        if stop_training:
            break

    # cleanup
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
2501
|
+
|
|
2502
|
+
|
|
2503
|
+
def _compute_training_loss(self, images: torch.Tensor, texts: List[str]) -> torch.Tensor:
|
|
2504
|
+
"""Computes the training loss for the UnCLIP prior model.
|
|
2505
|
+
|
|
2506
|
+
Calculates the loss by encoding images and text with CLIP, applying forward diffusion,
|
|
2507
|
+
predicting clean embeddings, and comparing with target embeddings.
|
|
2508
|
+
|
|
2509
|
+
Parameters
|
|
2510
|
+
----------
|
|
2511
|
+
`images` : torch.Tensor
|
|
2512
|
+
Input images, shape (batch_size, channels, height, width).
|
|
2513
|
+
`texts` : List[str]
|
|
2514
|
+
List of text prompts for conditioning.
|
|
2515
|
+
|
|
2516
|
+
Returns
|
|
2517
|
+
-------
|
|
2518
|
+
loss : torch.Tensor
|
|
2519
|
+
Loss value computed between predicted and target embeddings.
|
|
2520
|
+
"""
|
|
2521
|
+
|
|
2522
|
+
with torch.no_grad():
|
|
2523
|
+
# encode text and image with CLIP
|
|
2524
|
+
text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
|
|
2525
|
+
image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
|
|
2526
|
+
|
|
2527
|
+
# reduce dimensionality (optional)
|
|
2528
|
+
if self.reduce_clip_embedding_dim:
|
|
2529
|
+
text_embeddings = self.prior_model.clip_text_projection(text_embeddings)
|
|
2530
|
+
image_embeddings = self.prior_model.clip_image_projection(image_embeddings)
|
|
2531
|
+
|
|
2532
|
+
# sample timestep t ~ Uniform(1, T)
|
|
2533
|
+
batch_size = image_embeddings.shape[0]
|
|
2534
|
+
timesteps = torch.randint(0, self.prior_model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
|
|
2535
|
+
|
|
2536
|
+
# sample noise ε ~ N(0, I)
|
|
2537
|
+
noise = torch.randn_like(image_embeddings)
|
|
2538
|
+
|
|
2539
|
+
# compute noised embedding z_{i,t}
|
|
2540
|
+
noisy_image_embeddings = self.prior_model.forward_diffusion(image_embeddings, noise, timesteps)
|
|
2541
|
+
|
|
2542
|
+
# Predict unnoised embedding ẑ_i
|
|
2543
|
+
predicted_image_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)
|
|
2544
|
+
|
|
2545
|
+
# transform back to original space if using dimension reduction
|
|
2546
|
+
if self.reduce_clip_embedding_dim:
|
|
2547
|
+
predicted_image_embeddings = self.prior_model.clip_image_projection.inverse_transform(predicted_image_embeddings)
|
|
2548
|
+
target_embeddings = self.prior_model.clip_image_projection.inverse_transform(image_embeddings)
|
|
2549
|
+
else:
|
|
2550
|
+
target_embeddings = image_embeddings
|
|
2551
|
+
|
|
2552
|
+
# compute loss L = ||ẑ_i - z_i||²
|
|
2553
|
+
loss = self.objective(predicted_image_embeddings, target_embeddings)
|
|
2554
|
+
return loss
|
|
2555
|
+
|
|
2556
|
+
def _optimizer_step(self, scaler: torch.GradScaler) -> None:
|
|
2557
|
+
"""Performs an optimizer step with gradient clipping.
|
|
2558
|
+
|
|
2559
|
+
Applies gradient clipping, updates the optimizer with scaled gradients, and resets
|
|
2560
|
+
gradients for the next iteration.
|
|
2561
|
+
|
|
2562
|
+
Parameters
|
|
2563
|
+
----------
|
|
2564
|
+
`scaler` : torch.GradScaler
|
|
2565
|
+
Gradient scaler for mixed precision training.
|
|
2566
|
+
"""
|
|
2567
|
+
scaler.unscale_(self.optimizer)
|
|
2568
|
+
|
|
2569
|
+
# gradient clipping
|
|
2570
|
+
torch.nn.utils.clip_grad_norm_(self.prior_model.parameters(), max_norm=1.0)
|
|
2571
|
+
|
|
2572
|
+
scaler.step(self.optimizer)
|
|
2573
|
+
scaler.update()
|
|
2574
|
+
self.optimizer.zero_grad()
|
|
2575
|
+
|
|
2576
|
+
def _compute_mean_loss(self, losses: List[float]) -> float:
    """Computes the mean loss and synchronizes across processes if using DDP.

    Averages the provided loss values. In DDP mode the per-rank sums and counts are
    all-reduced with ``SUM``, so the result is the mean over every loss value on every
    rank. ``SUM`` is used because ``ReduceOp.AVG`` is only supported by the NCCL
    backend. This also weights ranks correctly when they see different numbers of batches.

    Parameters
    ----------
    `losses` : List[float]
        List of loss values from a training or validation epoch.

    Returns
    -------
    mean_loss : float
        Mean loss value, synchronized across processes if DDP is enabled. ``nan`` when
        there are no loss values at all.
    """
    total = float(sum(losses))
    count = float(len(losses))

    if self.use_ddp:
        # reduce (sum, count) together so one collective gives the global mean
        stats = torch.tensor([total, count], dtype=torch.float64, device=self.device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total, count = stats[0].item(), stats[1].item()

    # an empty epoch has no meaningful mean; keep the previous NaN result
    return total / count if count > 0 else float("nan")
|
|
2600
|
+
|
|
2601
|
+
|
|
2602
|
+
def validate(self) -> float:
    """Validates the UnCLIP prior model.

    For each validation batch:

    - encode images and texts with CLIP;
    - optionally project both into the prior's reduced space;
    - noise the image embedding at random timesteps and predict the clean embedding;
    - compare the prediction, mapped back to the original CLIP space when reduced,
      against the original CLIP image embedding.

    The model is put in eval mode for the pass and always returned to train mode,
    even if an exception is raised.

    Returns
    -------
    val_loss : float
        Mean validation loss, synchronized across processes if DDP is enabled.
    """
    # A DDP/DataParallel wrapper does not expose the wrapped model's attributes
    # (forward_diffusion, CLIP projections), so reach through to the real module.
    # Running the forward pass on it also avoids DDP collectives during evaluation.
    model = self.prior_model
    if isinstance(model, (nn.parallel.DistributedDataParallel, nn.DataParallel)):
        model = model.module

    self.prior_model.eval()
    val_losses = []
    try:
        with torch.no_grad():
            for images, texts in self.val_loader:
                images = images.to(self.device, non_blocking=True)

                # get embeddings
                text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
                image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
                # the loss target stays in the original CLIP space
                target_embeddings = image_embeddings.clone()

                if self.reduce_clip_embedding_dim:
                    text_embeddings = model.clip_text_projection(text_embeddings)
                    image_embeddings = model.clip_image_projection(image_embeddings)

                # forward diffusion at uniformly sampled timesteps
                batch_size = image_embeddings.shape[0]
                timesteps = torch.randint(0, model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
                noise = torch.randn_like(image_embeddings)
                noisy_image_embeddings = model.forward_diffusion(image_embeddings, noise, timesteps)

                # predict the clean embedding
                predicted_embeddings = model(text_embeddings, noisy_image_embeddings, timesteps)
                if self.reduce_clip_embedding_dim:
                    predicted_embeddings = model.clip_image_projection.inverse_transform(predicted_embeddings)

                loss = self.objective(predicted_embeddings, target_embeddings)
                val_losses.append(loss.item())

        val_loss = self._compute_mean_loss(val_losses)
    finally:
        # return to training mode
        self.prior_model.train()

    return val_loss
|
|
2655
|
+
|
|
2656
|
+
|
|
2657
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "", is_best: bool = False) -> None:
    """Saves a model checkpoint.

    Saves the state of the prior model and optimizer to a checkpoint file, with options
    for best model or early stopping checkpoints. Only the master process writes. In DDP
    every rank holds the same weights, and concurrent writes to one path could corrupt it.
    The file is written to a temporary path and then moved into place, so an interrupted
    save never leaves a truncated checkpoint. Failures are reported as a warning rather
    than raised, so training is not aborted.

    Parameters
    ----------
    `epoch` : int
        Current epoch number.
    `loss` : float
        Current loss value.
    `suffix` : str, optional
        Suffix to append to the checkpoint filename, default "". Ignored when ``is_best``.
    `is_best` : bool, optional
        Whether to save the checkpoint as the best model, default False.
    """
    if not self.master_process:
        return

    try:
        # unwrap DDP so the saved keys carry no 'module.' prefix
        prior_state = (
            self.prior_model.module.state_dict() if self.use_ddp
            else self.prior_model.state_dict()
        )

        checkpoint = {
            'epoch': epoch,
            'prior_model_state_dict': prior_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'max_epochs': self.max_epochs,
        }

        os.makedirs(self.store_path, exist_ok=True)
        filename = "best_model.pth" if is_best else f"checkpoint_epoch_{epoch}{suffix}.pth"
        save_path = os.path.join(self.store_path, filename)

        # write-then-rename: os.replace swaps the finished file in one step
        tmp_path = f"{save_path}.tmp"
        try:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, save_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"Checkpoint saved: {save_path}")

    except Exception as e:
        warnings.warn(f"Failed to save checkpoint: {e}")
|
|
2708
|
+
|
|
2709
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a model checkpoint to resume training.

    Restores the prior model and optimizer states from a saved checkpoint. The DDP
    ``module.`` key prefix is added or stripped to match the current setup. Only the
    leading prefix is stripped, so submodules that are themselves named ``module``
    keep their keys. A failure to restore the optimizer, or a checkpoint without
    model weights, is reported as a warning.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if not recorded).
    loss : float
        The loss value at the checkpoint (``inf`` if not recorded).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from err

    # load prior model
    if 'prior_model_state_dict' in checkpoint:
        state_dict = checkpoint['prior_model_state_dict']
        has_ddp_prefix = any(key.startswith('module.') for key in state_dict)

        # handle DDP state dict compatibility
        if self.use_ddp and not has_ddp_prefix:
            state_dict = {f'module.{k}': v for k, v in state_dict.items()}
        elif not self.use_ddp and has_ddp_prefix:
            state_dict = {k.removeprefix('module.'): v for k, v in state_dict.items()}

        self.prior_model.load_state_dict(state_dict)
    else:
        warnings.warn(f"Checkpoint {checkpoint_path} has no 'prior_model_state_dict'; prior model weights were not loaded")

    # load optimizer (best effort: a mismatched optimizer should not block resuming)
    if 'optimizer_state_dict' in checkpoint:
        try:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            warnings.warn(f"Failed to load optimizer state: {e}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} (epoch {epoch}, loss {loss:.4f})")

    return epoch, loss
|
|
2758
|
+
|
|
2759
|
+
###==================================================================================================================###
|
|
2760
|
+
|
|
2761
|
+
class SampleUnCLIP(nn.Module):
|
|
2762
|
+
"""Generates images using the UnCLIP model pipeline.
|
|
2763
|
+
|
|
2764
|
+
Combines a prior model, decoder model, CLIP model, and upsampler models to generate
|
|
2765
|
+
images from text prompts or noise. Performs diffusion-based sampling with classifier-free
|
|
2766
|
+
guidance in both prior and decoder stages, followed by upsampling to higher resolutions.
|
|
2767
|
+
|
|
2768
|
+
Parameters
|
|
2769
|
+
----------
|
|
2770
|
+
`prior_model` : nn.Module
|
|
2771
|
+
The UnCLIP prior model for generating image embeddings from text.
|
|
2772
|
+
`decoder_model` : nn.Module
|
|
2773
|
+
The UnCLIP decoder model for generating low-resolution images from embeddings.
|
|
2774
|
+
`clip_model` : nn.Module
|
|
2775
|
+
CLIP model for encoding text prompts into embeddings.
|
|
2776
|
+
`low_res_upsampler` : nn.Module
|
|
2777
|
+
First upsampler model for scaling images from 64x64 to 256x256.
|
|
2778
|
+
`high_res_upsampler` : nn.Module, optional
|
|
2779
|
+
Second upsampler model for scaling images from 256x256 to 1024x1024, default None.
|
|
2780
|
+
`device` : Union[torch.device, str], optional
|
|
2781
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
2782
|
+
`clip_embedding_dim` : int, optional
|
|
2783
|
+
Dimensionality of CLIP embeddings (default: 512).
|
|
2784
|
+
`prior_guidance_scale` : float, optional
|
|
2785
|
+
Classifier-free guidance scale for the prior model (default: 4.0).
|
|
2786
|
+
`decoder_guidance_scale` : float, optional
|
|
2787
|
+
Classifier-free guidance scale for the decoder model (default: 8.0).
|
|
2788
|
+
`batch_size` : int, optional
|
|
2789
|
+
Number of images to generate per batch (default: 1).
|
|
2790
|
+
`normalize` : bool, optional
|
|
2791
|
+
Whether to normalize CLIP embeddings (default: True).
|
|
2792
|
+
`prior_dim_reduction` : bool, optional
|
|
2793
|
+
Whether to apply dimensionality reduction in the prior model (default: True).
|
|
2794
|
+
`image_size` : Tuple[int, int, int], optional
|
|
2795
|
+
Size of the initial generated images (default: (3, 64, 64) for RGB 64x64).
|
|
2796
|
+
`use_high_res_upsampler` : bool, optional
|
|
2797
|
+
Whether to use the second upsampler for 1024x1024 output (default: True).
|
|
2798
|
+
`image_output_range` : Tuple[float, float], optional
|
|
2799
|
+
Range for clamping output images (default: (-1.0, 1.0)).
|
|
2800
|
+
"""
|
|
2801
|
+
def __init__(
|
|
2802
|
+
self,
|
|
2803
|
+
prior_model: nn.Module,
|
|
2804
|
+
decoder_model: nn.Module,
|
|
2805
|
+
clip_model: nn.Module,
|
|
2806
|
+
low_res_upsampler: nn.Module,
|
|
2807
|
+
high_res_upsampler: Optional[nn.Module] = None,
|
|
2808
|
+
device: Optional[Union[torch.device, str]] = None,
|
|
2809
|
+
clip_embedding_dim: int = 512, # CLIP embedding dimension
|
|
2810
|
+
prior_guidance_scale: float = 4.0,
|
|
2811
|
+
decoder_guidance_scale: float = 8.0,
|
|
2812
|
+
batch_size: int = 1,
|
|
2813
|
+
normalize_clip_embeddings: bool = True,
|
|
2814
|
+
prior_dim_reduction: bool = True,
|
|
2815
|
+
initial_image_size: Tuple[int, int, int] = (3, 64, 64),
|
|
2816
|
+
use_high_res_upsampler: bool = True,
|
|
2817
|
+
image_output_range: Tuple[float, float] = (-1.0, 1.0)
|
|
2818
|
+
) -> None:
|
|
2819
|
+
super().__init__()
|
|
2820
|
+
|
|
2821
|
+
if device is None:
|
|
2822
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
2823
|
+
elif isinstance(device, str):
|
|
2824
|
+
self.device = torch.device(device)
|
|
2825
|
+
else:
|
|
2826
|
+
self.device = device
|
|
2827
|
+
|
|
2828
|
+
self.prior_model = prior_model.to(self.device).eval()
|
|
2829
|
+
self.decoder_model = decoder_model.to(self.device).eval()
|
|
2830
|
+
self.clip_model = clip_model.to(self.device).eval()
|
|
2831
|
+
self.low_res_upsampler = low_res_upsampler.to(self.device).eval()
|
|
2832
|
+
self.high_res_upsampler = high_res_upsampler.to(self.device).eval() if high_res_upsampler else None
|
|
2833
|
+
|
|
2834
|
+
self.prior_guidance_scale = prior_guidance_scale
|
|
2835
|
+
self.decoder_guidance_scale = decoder_guidance_scale
|
|
2836
|
+
self.batch_size = batch_size
|
|
2837
|
+
self.normalize_clip_embeddings = normalize_clip_embeddings
|
|
2838
|
+
self.prior_dim_reduction = prior_dim_reduction
|
|
2839
|
+
self.clip_embedding_dim = clip_embedding_dim
|
|
2840
|
+
self.initial_image_size = initial_image_size
|
|
2841
|
+
self.use_high_res_upsampler = use_high_res_upsampler
|
|
2842
|
+
self.image_output_range = image_output_range
|
|
2843
|
+
self.images_256 = None
|
|
2844
|
+
self.images_1024 = None
|
|
2845
|
+
|
|
2846
|
+
def forward(
|
|
2847
|
+
self,
|
|
2848
|
+
prompts: Optional[Union[str, List]] = None,
|
|
2849
|
+
normalize_output: bool = True,
|
|
2850
|
+
save_images: bool = True,
|
|
2851
|
+
save_path: str = "unclip_generated"
|
|
2852
|
+
):
|
|
2853
|
+
"""Generates images from text prompts or noise using the UnCLIP pipeline.
|
|
2854
|
+
|
|
2855
|
+
Executes the full UnCLIP generation process: prior model generates image embeddings,
|
|
2856
|
+
decoder model generates 64x64 images, first upsampler scales to 256x256, and optional
|
|
2857
|
+
second upsampler scales to 1024x1024. Supports classifier-free guidance and saves
|
|
2858
|
+
generated images if requested.
|
|
2859
|
+
|
|
2860
|
+
Parameters
|
|
2861
|
+
----------
|
|
2862
|
+
`prompts` : Union[str, List], optional
|
|
2863
|
+
Text prompt(s) for conditional generation, default None (unconditional).
|
|
2864
|
+
`normalize_output` : bool, optional
|
|
2865
|
+
Whether to normalize output images to [0, 1] range (default: True).
|
|
2866
|
+
`save_images` : bool, optional
|
|
2867
|
+
Whether to save generated images to disk (default: True).
|
|
2868
|
+
`save_path` : str, optional
|
|
2869
|
+
Directory to save generated images (default: "unclip_generated").
|
|
2870
|
+
|
|
2871
|
+
Returns
|
|
2872
|
+
-------
|
|
2873
|
+
final_images : torch.Tensor
|
|
2874
|
+
Generated images, shape (batch_size, channels, height, width), either 256x256
|
|
2875
|
+
or 1024x1024 depending on use_second_upsampler.
|
|
2876
|
+
"""
|
|
2877
|
+
# initialize noise for prior sampling (image embedding space)
|
|
2878
|
+
embedding_noise = torch.randn((self.batch_size, self.clip_embedding_dim), device=self.device)
|
|
2879
|
+
|
|
2880
|
+
with torch.no_grad():
|
|
2881
|
+
|
|
2882
|
+
# ====== PRIOR STAGE: generate image embeddings from text ======
|
|
2883
|
+
# encode text prompt using CLIP
|
|
2884
|
+
text_embeddings = self.clip_model(data=prompts, data_type="text", normalize=self.normalize_clip_embeddings)
|
|
2885
|
+
current_embeddings = embedding_noise.clone()
|
|
2886
|
+
|
|
2887
|
+
# optionally reduce dimensionality for prior model
|
|
2888
|
+
if self.prior_dim_reduction:
|
|
2889
|
+
text_embeddings_reduced = self.prior_model.clip_text_projection(text_embeddings)
|
|
2890
|
+
current_embeddings_reduced = self.prior_model.clip_image_projection(current_embeddings)
|
|
2891
|
+
else:
|
|
2892
|
+
text_embeddings_reduced = text_embeddings
|
|
2893
|
+
current_embeddings_reduced = current_embeddings
|
|
2894
|
+
|
|
2895
|
+
# prior diffusion sampling loop
|
|
2896
|
+
for t in reversed(range(self.prior_model.forward_diffusion.variance_scheduler.tau_num_steps)):
|
|
2897
|
+
timesteps = torch.full((self.batch_size,), t, device=self.device)
|
|
2898
|
+
prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
|
|
2899
|
+
|
|
2900
|
+
# predict embeddings
|
|
2901
|
+
predicted_embeddings = self.prior_model(text_embeddings_reduced, current_embeddings_reduced, timesteps)
|
|
2902
|
+
|
|
2903
|
+
# apply guidance
|
|
2904
|
+
guided_embeddings = self.compute_prior_guided_prediction(
|
|
2905
|
+
predicted_embeddings, text_embeddings_reduced, current_embeddings_reduced, timesteps
|
|
2906
|
+
)
|
|
2907
|
+
|
|
2908
|
+
# update embeddings using reverse diffusion
|
|
2909
|
+
current_embeddings_reduced, _ = self.prior_model.reverse_diffusion(
|
|
2910
|
+
current_embeddings_reduced, guided_embeddings, timesteps, prev_timesteps
|
|
2911
|
+
)
|
|
2912
|
+
|
|
2913
|
+
# convert back to full embedding dimension if needed
|
|
2914
|
+
if self.prior_dim_reduction:
|
|
2915
|
+
final_image_embeddings = self.prior_model.clip_image_projection.inverse_transform(current_embeddings_reduced)
|
|
2916
|
+
else:
|
|
2917
|
+
final_image_embeddings = current_embeddings_reduced
|
|
2918
|
+
|
|
2919
|
+
# ====== DECODER STAGE: generate 64x64 images from embeddings ======
|
|
2920
|
+
# initialize noise for decoder sampling
|
|
2921
|
+
decoder_noise = torch.randn((self.batch_size, self.initial_image_size[0], self.initial_image_size[1], self.initial_image_size[2]), device=self.device)
|
|
2922
|
+
|
|
2923
|
+
# project image embeddings to 4 tokens
|
|
2924
|
+
projected_embeddings = self.decoder_model.clip_decoder_projection(final_image_embeddings)
|
|
2925
|
+
|
|
2926
|
+
# encode text with GLIDE/decoder's text encoder
|
|
2927
|
+
glide_text_embeddings = self.decoder_model._encode_text_with_glide(prompts)
|
|
2928
|
+
|
|
2929
|
+
# concatenate embeddings for context
|
|
2930
|
+
context = self.decoder_model._concatenate_embeddings(glide_text_embeddings, projected_embeddings)
|
|
2931
|
+
|
|
2932
|
+
current_images = decoder_noise
|
|
2933
|
+
|
|
2934
|
+
for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
|
|
2935
|
+
|
|
2936
|
+
timesteps = torch.full((self.batch_size,), t, device=self.device)
|
|
2937
|
+
prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
|
|
2938
|
+
|
|
2939
|
+
# predict noise
|
|
2940
|
+
predicted_noise = self.decoder_model.noise_predictor(current_images, timesteps, context, None)
|
|
2941
|
+
|
|
2942
|
+
# apply guidance
|
|
2943
|
+
guided_noise = self.compute_decoder_guided_prediction(
|
|
2944
|
+
predicted_noise, current_images, timesteps, context
|
|
2945
|
+
)
|
|
2946
|
+
|
|
2947
|
+
# update images using reverse diffusion
|
|
2948
|
+
current_images, _ = self.decoder_model.reverse_diffusion(
|
|
2949
|
+
current_images, guided_noise, timesteps, prev_timesteps
|
|
2950
|
+
)
|
|
2951
|
+
|
|
2952
|
+
generated_64x64 = current_images
|
|
2953
|
+
|
|
2954
|
+
# ====== FIRST UPSAMPLER: 64x64 -> 256x256 ======
|
|
2955
|
+
upsampled_256_noise = torch.randn((self.batch_size, self.initial_image_size[0], 256, 256), device=self.device)
|
|
2956
|
+
current_256_images = upsampled_256_noise
|
|
2957
|
+
|
|
2958
|
+
for t in reversed(range(self.low_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
|
|
2959
|
+
timesteps = torch.full((self.batch_size,), t, device=self.device)
|
|
2960
|
+
prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
|
|
2961
|
+
|
|
2962
|
+
# predict noise for upsampling (conditioned on low-res image)
|
|
2963
|
+
predicted_noise = self.low_res_upsampler(current_256_images, timesteps, generated_64x64)
|
|
2964
|
+
|
|
2965
|
+
# update using reverse diffusion
|
|
2966
|
+
current_256_images, _ = self.low_res_upsampler.reverse_diffusion(
|
|
2967
|
+
current_256_images, predicted_noise, timesteps, prev_timesteps
|
|
2968
|
+
)
|
|
2969
|
+
|
|
2970
|
+
self.images_256 = current_256_images
|
|
2971
|
+
|
|
2972
|
+
# ====== SECOND UPSAMPLER: 256x256 -> 1024x1024 (if enabled) ======
|
|
2973
|
+
if self.use_high_res_upsampler and self.high_res_upsampler:
|
|
2974
|
+
upsampled_1024_noise = torch.randn((self.batch_size, self.initial_image_size[0], 1024, 1024), device=self.device)
|
|
2975
|
+
current_1024_images = upsampled_1024_noise
|
|
2976
|
+
|
|
2977
|
+
for t in reversed(range(self.high_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
|
|
2978
|
+
timesteps = torch.full((self.batch_size,), t, device=self.device)
|
|
2979
|
+
prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
|
|
2980
|
+
|
|
2981
|
+
# predict noise for upsampling (conditioned on 256x256 image)
|
|
2982
|
+
predicted_noise = self.high_res_upsampler(current_1024_images, timesteps, self.images_256)
|
|
2983
|
+
|
|
2984
|
+
# update using reverse diffusion
|
|
2985
|
+
current_1024_images, _ = self.high_res_upsampler.reverse_diffusion(
|
|
2986
|
+
current_1024_images, predicted_noise, timesteps, prev_timesteps
|
|
2987
|
+
)
|
|
2988
|
+
|
|
2989
|
+
self.images_1024 = current_1024_images
|
|
2990
|
+
|
|
2991
|
+
# ====== POST-PROCESSING ======
|
|
2992
|
+
# normalize output to [0, 1] range if requested
|
|
2993
|
+
if normalize_output:
|
|
2994
|
+
final_256 = (self.images_256 - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
|
|
2995
|
+
final_1024 = None
|
|
2996
|
+
if self.images_1024 is not None:
|
|
2997
|
+
final_1024 = (self.images_1024 - self.image_output_range[0]) / (
|
|
2998
|
+
self.image_output_range[1] - self.image_output_range[0])
|
|
2999
|
+
else:
|
|
3000
|
+
final_256 = self.images_256
|
|
3001
|
+
final_1024 = self.images_1024
|
|
3002
|
+
|
|
3003
|
+
# save images if requested
|
|
3004
|
+
if save_images:
|
|
3005
|
+
os.makedirs(save_path, exist_ok=True)
|
|
3006
|
+
os.makedirs(os.path.join(save_path, "images_256"), exist_ok=True)
|
|
3007
|
+
if final_1024 is not None:
|
|
3008
|
+
os.makedirs(os.path.join(save_path, "images_1024"), exist_ok=True)
|
|
3009
|
+
|
|
3010
|
+
for i in range(self.batch_size):
|
|
3011
|
+
img_path_256 = os.path.join(save_path, "images_256", f"image_{i+1}.png")
|
|
3012
|
+
torchvision.utils.save_image(final_256[i], img_path_256)
|
|
3013
|
+
|
|
3014
|
+
if final_1024 is not None:
|
|
3015
|
+
img_path_1024 = os.path.join(save_path, "images_1024", f"image_{i+1}.png")
|
|
3016
|
+
torchvision.utils.save_image(final_1024[i], img_path_1024)
|
|
3017
|
+
|
|
3018
|
+
# return final images
|
|
3019
|
+
if final_1024 is not None:
|
|
3020
|
+
return final_1024
|
|
3021
|
+
else:
|
|
3022
|
+
return final_256
|
|
3023
|
+
|
|
3024
|
+
def compute_prior_guided_prediction(
|
|
3025
|
+
self,
|
|
3026
|
+
predicted_embeddings: torch.Tensor,
|
|
3027
|
+
text_embeddings: torch.Tensor,
|
|
3028
|
+
current_embeddings: torch.Tensor,
|
|
3029
|
+
timesteps: torch.Tensor
|
|
3030
|
+
) -> torch.Tensor:
|
|
3031
|
+
"""Computes classifier-free guidance for the prior model.
|
|
3032
|
+
|
|
3033
|
+
Combines conditioned and unconditioned predictions using the classifier-free guidance
|
|
3034
|
+
formula to enhance the quality of generated image embeddings.
|
|
3035
|
+
|
|
3036
|
+
Parameters
|
|
3037
|
+
----------
|
|
3038
|
+
`predicted_embeddings` : torch.Tensor
|
|
3039
|
+
Conditioned predicted embeddings, shape (batch_size, embedding_dim).
|
|
3040
|
+
`text_embeddings` : torch.Tensor
|
|
3041
|
+
Text embeddings from CLIP, shape (batch_size, embedding_dim).
|
|
3042
|
+
`current_embeddings` : torch.Tensor
|
|
3043
|
+
Current noisy embeddings, shape (batch_size, embedding_dim).
|
|
3044
|
+
`timesteps` : torch.Tensor
|
|
3045
|
+
Timestep indices, shape (batch_size,).
|
|
3046
|
+
|
|
3047
|
+
Returns
|
|
3048
|
+
-------
|
|
3049
|
+
guided_embeddings : torch.Tensor
|
|
3050
|
+
Guided embeddings, shape (batch_size, embedding_dim).
|
|
3051
|
+
"""
|
|
3052
|
+
# use zero embeddings for unconditional generation
|
|
3053
|
+
zero_text_embeddings = torch.zeros_like(text_embeddings)
|
|
3054
|
+
unconditioned_pred = self.prior_model(zero_text_embeddings, current_embeddings, timesteps)
|
|
3055
|
+
|
|
3056
|
+
# CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
|
|
3057
|
+
return (1.0 + self.prior_guidance_scale) * predicted_embeddings - self.prior_guidance_scale * unconditioned_pred
|
|
3058
|
+
|
|
3059
|
+
def compute_decoder_guided_prediction(
|
|
3060
|
+
self,
|
|
3061
|
+
predicted_noise: torch.Tensor,
|
|
3062
|
+
current_images: torch.Tensor,
|
|
3063
|
+
timesteps: torch.Tensor,
|
|
3064
|
+
context: torch.Tensor
|
|
3065
|
+
) -> torch.Tensor:
|
|
3066
|
+
"""Computes classifier-free guidance for the decoder model.
|
|
3067
|
+
|
|
3068
|
+
Combines conditioned and unconditioned noise predictions using the classifier-free
|
|
3069
|
+
guidance formula to enhance the quality of generated images.
|
|
3070
|
+
|
|
3071
|
+
Parameters
|
|
3072
|
+
----------
|
|
3073
|
+
`predicted_noise` : torch.Tensor
|
|
3074
|
+
Conditioned predicted noise, shape (batch_size, channels, height, width).
|
|
3075
|
+
`current_images` : torch.Tensor
|
|
3076
|
+
Current noisy images, shape (batch_size, channels, height, width).
|
|
3077
|
+
`timesteps` : torch.Tensor
|
|
3078
|
+
Timestep indices, shape (batch_size,).
|
|
3079
|
+
`context` : torch.Tensor
|
|
3080
|
+
Context embeddings (concatenated GLIDE text and projected image embeddings),
|
|
3081
|
+
shape (batch_size, seq_len, embedding_dim).
|
|
3082
|
+
|
|
3083
|
+
Returns
|
|
3084
|
+
-------
|
|
3085
|
+
guided_noise : torch.Tensor
|
|
3086
|
+
Guided noise prediction, shape (batch_size, channels, height, width).
|
|
3087
|
+
"""
|
|
3088
|
+
zero_context = torch.zeros_like(context)
|
|
3089
|
+
unconditioned_noise = self.decoder_model.noise_predictor(current_images, timesteps, zero_context, None)
|
|
3090
|
+
|
|
3091
|
+
# CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
|
|
3092
|
+
return (1.0 + self.decoder_guidance_scale) * predicted_noise - self.decoder_guidance_scale * unconditioned_noise
|
|
3093
|
+
|
|
3094
|
+
def to(self, device: Union[torch.device, str]) -> Self:
|
|
3095
|
+
"""Moves the module and all its components to the specified device.
|
|
3096
|
+
|
|
3097
|
+
Updates the device attribute and moves all sub-models (prior, decoder, CLIP,
|
|
3098
|
+
and upsamplers) to the specified device.
|
|
3099
|
+
|
|
3100
|
+
Parameters
|
|
3101
|
+
----------
|
|
3102
|
+
device : Union[torch.device, str]
|
|
3103
|
+
Target device for the module and its components.
|
|
3104
|
+
|
|
3105
|
+
Returns
|
|
3106
|
+
-------
|
|
3107
|
+
SampleUnCLIP
|
|
3108
|
+
The module moved to the specified device.
|
|
3109
|
+
"""
|
|
3110
|
+
if isinstance(device, str):
|
|
3111
|
+
device = torch.device(device)
|
|
3112
|
+
|
|
3113
|
+
self.device = device
|
|
3114
|
+
|
|
3115
|
+
# move all sub-models to the specified device
|
|
3116
|
+
self.prior_model.to(device)
|
|
3117
|
+
self.decoder_model.to(device)
|
|
3118
|
+
self.clip_model.to(device)
|
|
3119
|
+
self.low_res_upsampler.to(device)
|
|
3120
|
+
|
|
3121
|
+
if self.second_upsampler_model is not None:
|
|
3122
|
+
self.second_upsampler_model.to(device)
|
|
3123
|
+
|
|
3124
|
+
return super().to(device)
|
|
3125
|
+
|
|
3126
|
+
###==================================================================================================================###
|
|
3127
|
+
|
|
3128
|
+
class UpsamplerUnCLIP(nn.Module):
|
|
3129
|
+
"""Diffusion-based upsampler for UnCLIP models.
|
|
3130
|
+
|
|
3131
|
+
A U-Net-like model that upsamples low-resolution images to high-resolution images,
|
|
3132
|
+
conditioned on noisy high-resolution images and timesteps, using residual blocks,
|
|
3133
|
+
downsampling, and upsampling layers.
|
|
3134
|
+
|
|
3135
|
+
Parameters
|
|
3136
|
+
----------
|
|
3137
|
+
`forward_diffusion` : nn.Module
|
|
3138
|
+
Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
|
|
3139
|
+
`in_channels` : int, optional
|
|
3140
|
+
Number of input channels (default: 3, for RGB images).
|
|
3141
|
+
`out_channels` : int, optional
|
|
3142
|
+
Number of output channels (default: 3, for RGB noise prediction).
|
|
3143
|
+
`model_channels` : int, optional
|
|
3144
|
+
Base number of channels in the model (default: 192).
|
|
3145
|
+
`num_res_blocks` : int, optional
|
|
3146
|
+
Number of residual blocks per resolution level (default: 2).
|
|
3147
|
+
`channel_mult` : Tuple[int, ...], optional
|
|
3148
|
+
Channel multiplier for each resolution level (default: (1, 2, 4, 8)).
|
|
3149
|
+
`dropout` : float, optional
|
|
3150
|
+
Dropout probability for regularization (default: 0.1).
|
|
3151
|
+
`time_embed_dim` : int, optional
|
|
3152
|
+
Dimensionality of time embeddings (default: 768).
|
|
3153
|
+
`low_res_size` : int, optional
|
|
3154
|
+
Spatial size of low-resolution input (default: 64).
|
|
3155
|
+
`high_res_size` : int, optional
|
|
3156
|
+
Spatial size of high-resolution output (default: 256).
|
|
3157
|
+
"""
|
|
3158
|
+
|
|
3159
|
+
def __init__(
    self,
    forward_diffusion: nn.Module,
    reverse_diffusion: nn.Module,
    in_channels: int = 3,
    out_channels: int = 3,
    model_channels: int = 192,
    num_res_blocks: int = 2,
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
    dropout_rate: float = 0.1,
    time_embed_dim: int = 768,
    low_res_size: int = 64,
    high_res_size: int = 256,
) -> None:
    """Builds the upsampler's layers.

    Sub-modules are created in this order: time embedding, input projection, encoder
    (residual + downsample blocks), two middle residual blocks, decoder (residual +
    upsample blocks), output head.
    """
    super().__init__()

    # forward diffusion is used at training time (inside 'TrainUpsamplerUnCLIP'),
    # reverse diffusion at inference time
    self.forward_diffusion = forward_diffusion
    self.reverse_diffusion = reverse_diffusion
    self.in_channels, self.out_channels = in_channels, out_channels
    self.model_channels = model_channels
    self.num_res_blocks = num_res_blocks
    self.low_res_size, self.high_res_size = low_res_size, high_res_size

    # sinusoidal timestep features followed by a two-layer MLP
    self.time_embed = nn.Sequential(
        SinusoidalPositionalEmbedding(model_channels),
        nn.Linear(model_channels, time_embed_dim),
        nn.SiLU(),
        nn.Linear(time_embed_dim, time_embed_dim),
    )

    # input: noisy high-res image stacked channel-wise with the upsampled low-res image
    self.input_proj = nn.Conv2d(2 * in_channels, model_channels, 3, padding=1)

    widths = [model_channels * m for m in channel_mult]
    deepest = len(widths) - 1

    # encoder: num_res_blocks residual blocks per level, a downsample between levels
    self.encoder_blocks = nn.ModuleList()
    self.downsample_blocks = nn.ModuleList()
    channels = model_channels
    for depth, width in enumerate(widths):
        for _ in range(num_res_blocks):
            self.encoder_blocks.append(ResBlock(channels, width, time_embed_dim, dropout_rate))
            channels = width
        if depth < deepest:
            self.downsample_blocks.append(DownsampleBlock(channels, channels))

    # bottleneck: two width-preserving residual blocks
    self.middle_blocks = nn.ModuleList(
        [ResBlock(channels, channels, time_embed_dim, dropout_rate) for _ in range(2)]
    )

    # decoder: levels in reverse, num_res_blocks + 1 residual blocks each; the first
    # block of a level gets extra input channels (its level width) for the skip connection
    self.decoder_blocks = nn.ModuleList()
    self.upsample_blocks = nn.ModuleList()
    for depth in range(deepest, -1, -1):
        width = widths[depth]
        for block_idx in range(num_res_blocks + 1):
            skip_channels = width if block_idx == 0 else 0
            self.decoder_blocks.append(
                ResBlock(channels + skip_channels, width, time_embed_dim, dropout_rate)
            )
            channels = width
        if depth > 0:
            self.upsample_blocks.append(UpsampleBlock(channels, channels))

    # output head: GroupNorm -> SiLU -> 3x3 conv to out_channels
    self.output_proj = nn.Sequential(
        nn.GroupNorm(8, channels),
        nn.SiLU(),
        nn.Conv2d(channels, out_channels, 3, padding=1),
    )
|
|
3241
|
+
|
|
3242
|
+
def forward(self, x_high: torch.Tensor, t: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
|
|
3243
|
+
"""Predicts noise for the upsampling process.
|
|
3244
|
+
|
|
3245
|
+
Processes a noisy high-resolution image and a low-resolution conditioning image,
|
|
3246
|
+
conditioned on timesteps, to predict the noise component for denoising.
|
|
3247
|
+
|
|
3248
|
+
Parameters
|
|
3249
|
+
----------
|
|
3250
|
+
`x_high` : torch.Tensor
|
|
3251
|
+
Noisy high-resolution image, shape (batch_size, in_channels, high_res_size, high_res_size).
|
|
3252
|
+
`t` : torch.Tensor
|
|
3253
|
+
Timestep indices, shape (batch_size,).
|
|
3254
|
+
`x_low` : torch.Tensor
|
|
3255
|
+
Low-resolution conditioning image, shape (batch_size, in_channels, low_res_size, low_res_size).
|
|
3256
|
+
|
|
3257
|
+
Returns
|
|
3258
|
+
-------
|
|
3259
|
+
out : torch.Tensor
|
|
3260
|
+
Predicted noise, shape (batch_size, out_channels, high_res_size, high_res_size).
|
|
3261
|
+
"""
|
|
3262
|
+
# upsample low-resolution image to match high-resolution
|
|
3263
|
+
x_low_upsampled = F.interpolate(
|
|
3264
|
+
x_low,
|
|
3265
|
+
size=(x_high.shape[-2], x_high.shape[-1]),
|
|
3266
|
+
mode='bicubic',
|
|
3267
|
+
align_corners=False
|
|
3268
|
+
)
|
|
3269
|
+
|
|
3270
|
+
# concatenate noisy high-res and upsampled low-res
|
|
3271
|
+
x = torch.cat([x_high, x_low_upsampled], dim=1)
|
|
3272
|
+
|
|
3273
|
+
# time embedding
|
|
3274
|
+
time_emb = self.time_embed(t.float()) # Ensure float for embedding
|
|
3275
|
+
|
|
3276
|
+
# input projection
|
|
3277
|
+
h = self.input_proj(x)
|
|
3278
|
+
|
|
3279
|
+
# store skip connections
|
|
3280
|
+
skip_connections = []
|
|
3281
|
+
|
|
3282
|
+
# encoder
|
|
3283
|
+
for i, block in enumerate(self.encoder_blocks):
|
|
3284
|
+
h = block(h, time_emb)
|
|
3285
|
+
if (i + 1) % self.num_res_blocks == 0:
|
|
3286
|
+
skip_connections.append(h)
|
|
3287
|
+
downsample_idx = (i + 1) // self.num_res_blocks - 1
|
|
3288
|
+
if downsample_idx < len(self.downsample_blocks):
|
|
3289
|
+
h = self.downsample_blocks[downsample_idx](h)
|
|
3290
|
+
|
|
3291
|
+
# middle
|
|
3292
|
+
for i, block in enumerate(self.middle_blocks):
|
|
3293
|
+
h = block(h, time_emb)
|
|
3294
|
+
|
|
3295
|
+
# decoder
|
|
3296
|
+
upsample_idx = 0
|
|
3297
|
+
for i, block in enumerate(self.decoder_blocks):
|
|
3298
|
+
# add skip connection
|
|
3299
|
+
if i % (self.num_res_blocks + 1) == 0 and skip_connections:
|
|
3300
|
+
skip = skip_connections.pop()
|
|
3301
|
+
h = torch.cat([h, skip], dim=1)
|
|
3302
|
+
|
|
3303
|
+
h = block(h, time_emb)
|
|
3304
|
+
|
|
3305
|
+
# upsample at the end of each resolution level
|
|
3306
|
+
if ((i + 1) % (self.num_res_blocks + 1) == 0 and
|
|
3307
|
+
upsample_idx < len(self.upsample_blocks)):
|
|
3308
|
+
h = self.upsample_blocks[upsample_idx](h)
|
|
3309
|
+
upsample_idx += 1
|
|
3310
|
+
|
|
3311
|
+
# output projection
|
|
3312
|
+
out = self.output_proj(h)
|
|
3313
|
+
|
|
3314
|
+
return out
|
|
3315
|
+
|
|
3316
|
+
|
|
3317
|
+
|
|
3318
|
+
class SinusoidalPositionalEmbedding(nn.Module):
|
|
3319
|
+
"""Sinusoidal positional embedding for timesteps.
|
|
3320
|
+
|
|
3321
|
+
Generates sinusoidal embeddings for timesteps to condition the upsampler on the
|
|
3322
|
+
diffusion process stage.
|
|
3323
|
+
|
|
3324
|
+
Parameters
|
|
3325
|
+
----------
|
|
3326
|
+
`dim` : int
|
|
3327
|
+
Dimensionality of the embedding.
|
|
3328
|
+
"""
|
|
3329
|
+
|
|
3330
|
+
def __init__(self, dim: int):
|
|
3331
|
+
super().__init__()
|
|
3332
|
+
self.dim = dim
|
|
3333
|
+
|
|
3334
|
+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
|
3335
|
+
"""Generates sinusoidal embeddings for timesteps.
|
|
3336
|
+
|
|
3337
|
+
Parameters
|
|
3338
|
+
----------
|
|
3339
|
+
`timesteps` : torch.Tensor
|
|
3340
|
+
Timestep indices, shape (batch_size,).
|
|
3341
|
+
|
|
3342
|
+
Returns
|
|
3343
|
+
-------
|
|
3344
|
+
embeddings : torch.Tensor
|
|
3345
|
+
Sinusoidal embeddings, shape (batch_size, dim).
|
|
3346
|
+
"""
|
|
3347
|
+
device = timesteps.device
|
|
3348
|
+
half_dim = self.dim // 2
|
|
3349
|
+
embeddings = math.log(10000) / (half_dim - 1)
|
|
3350
|
+
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
|
|
3351
|
+
embeddings = timesteps[:, None] * embeddings[None, :]
|
|
3352
|
+
embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
|
|
3353
|
+
return embeddings
|
|
3354
|
+
|
|
3355
|
+
|
|
3356
|
+
class ResBlock(nn.Module):
|
|
3357
|
+
"""Residual block with time embedding and conditioning.
|
|
3358
|
+
|
|
3359
|
+
A convolutional residual block with group normalization, time embedding conditioning,
|
|
3360
|
+
and optional scale-shift normalization, used in the UnCLIP upsampler.
|
|
3361
|
+
|
|
3362
|
+
Parameters
|
|
3363
|
+
----------
|
|
3364
|
+
`in_channels` : int
|
|
3365
|
+
Number of input channels.
|
|
3366
|
+
`out_channels` : int
|
|
3367
|
+
Number of output channels.
|
|
3368
|
+
`time_embed_dim` : int
|
|
3369
|
+
Dimensionality of time embeddings.
|
|
3370
|
+
`dropout` : float, optional
|
|
3371
|
+
Dropout probability (default: 0.1).
|
|
3372
|
+
`use_scale_shift_norm` : bool, optional
|
|
3373
|
+
Whether to use scale-shift normalization for time embeddings (default: True).
|
|
3374
|
+
"""
|
|
3375
|
+
def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int,
|
|
3376
|
+
dropout: float = 0.1, use_scale_shift_norm: bool = True):
|
|
3377
|
+
super().__init__()
|
|
3378
|
+
self.use_scale_shift_norm = use_scale_shift_norm
|
|
3379
|
+
|
|
3380
|
+
self.in_layers = nn.Sequential(
|
|
3381
|
+
nn.GroupNorm(8, in_channels),
|
|
3382
|
+
nn.SiLU(),
|
|
3383
|
+
nn.Conv2d(in_channels, out_channels, 3, padding=1)
|
|
3384
|
+
)
|
|
3385
|
+
|
|
3386
|
+
self.time_emb_proj = nn.Sequential(
|
|
3387
|
+
nn.SiLU(),
|
|
3388
|
+
nn.Linear(time_embed_dim, out_channels * 2 if use_scale_shift_norm else out_channels)
|
|
3389
|
+
)
|
|
3390
|
+
|
|
3391
|
+
self.out_norm = nn.GroupNorm(8, out_channels)
|
|
3392
|
+
self.out_rest = nn.Sequential(
|
|
3393
|
+
nn.SiLU(),
|
|
3394
|
+
nn.Dropout(dropout),
|
|
3395
|
+
nn.Conv2d(out_channels, out_channels, 3, padding=1)
|
|
3396
|
+
)
|
|
3397
|
+
|
|
3398
|
+
if in_channels != out_channels:
|
|
3399
|
+
self.skip_connection = nn.Conv2d(in_channels, out_channels, 1)
|
|
3400
|
+
else:
|
|
3401
|
+
self.skip_connection = nn.Identity()
|
|
3402
|
+
|
|
3403
|
+
def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
|
|
3404
|
+
"""Processes input through the residual block with time conditioning.
|
|
3405
|
+
|
|
3406
|
+
Parameters
|
|
3407
|
+
----------
|
|
3408
|
+
`x` : torch.Tensor
|
|
3409
|
+
Input tensor, shape (batch_size, in_channels, height, width).
|
|
3410
|
+
`time_emb` : torch.Tensor
|
|
3411
|
+
Time embeddings, shape (batch_size, time_embed_dim).
|
|
3412
|
+
|
|
3413
|
+
Returns
|
|
3414
|
+
-------
|
|
3415
|
+
out : torch.Tensor
|
|
3416
|
+
Output tensor, shape (batch_size, out_channels, height, width).
|
|
3417
|
+
"""
|
|
3418
|
+
h = self.in_layers(x)
|
|
3419
|
+
|
|
3420
|
+
# apply time embedding
|
|
3421
|
+
emb_out = self.time_emb_proj(time_emb)[:, :, None, None]
|
|
3422
|
+
|
|
3423
|
+
if self.use_scale_shift_norm:
|
|
3424
|
+
scale, shift = torch.chunk(emb_out, 2, dim=1)
|
|
3425
|
+
h = self.out_norm(h) * (1 + scale) + shift
|
|
3426
|
+
h = self.out_rest(h)
|
|
3427
|
+
else:
|
|
3428
|
+
h = h + emb_out
|
|
3429
|
+
h = self.out_norm(h)
|
|
3430
|
+
h = self.out_rest(h)
|
|
3431
|
+
|
|
3432
|
+
return h + self.skip_connection(x)
|
|
3433
|
+
|
|
3434
|
+
|
|
3435
|
+
class UpsampleBlock(nn.Module):
|
|
3436
|
+
"""Upsampling block using transposed convolution.
|
|
3437
|
+
|
|
3438
|
+
Increases the spatial resolution of the input tensor using a transposed convolution.
|
|
3439
|
+
|
|
3440
|
+
Parameters
|
|
3441
|
+
----------
|
|
3442
|
+
`in_channels` : int
|
|
3443
|
+
Number of input channels.
|
|
3444
|
+
`out_channels` : int
|
|
3445
|
+
Number of output channels.
|
|
3446
|
+
"""
|
|
3447
|
+
|
|
3448
|
+
def __init__(self, in_channels: int, out_channels: int):
|
|
3449
|
+
super().__init__()
|
|
3450
|
+
self.conv = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
|
|
3451
|
+
|
|
3452
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
3453
|
+
"""Upsamples the input tensor.
|
|
3454
|
+
|
|
3455
|
+
Parameters
|
|
3456
|
+
----------
|
|
3457
|
+
`x` : torch.Tensor
|
|
3458
|
+
Input tensor, shape (batch_size, in_channels, height, width).
|
|
3459
|
+
|
|
3460
|
+
Returns
|
|
3461
|
+
-------
|
|
3462
|
+
out : torch.Tensor
|
|
3463
|
+
Upsampled tensor, shape (batch_size, out_channels, height*2, width*2).
|
|
3464
|
+
"""
|
|
3465
|
+
return self.conv(x)
|
|
3466
|
+
|
|
3467
|
+
|
|
3468
|
+
class DownsampleBlock(nn.Module):
|
|
3469
|
+
"""Downsampling block using strided convolution.
|
|
3470
|
+
|
|
3471
|
+
Reduces the spatial resolution of the input tensor using a strided convolution.
|
|
3472
|
+
|
|
3473
|
+
Parameters
|
|
3474
|
+
----------
|
|
3475
|
+
`in_channels` : int
|
|
3476
|
+
Number of input channels.
|
|
3477
|
+
`out_channels` : int
|
|
3478
|
+
Number of output channels.
|
|
3479
|
+
"""
|
|
3480
|
+
|
|
3481
|
+
def __init__(self, in_channels: int, out_channels: int):
|
|
3482
|
+
super().__init__()
|
|
3483
|
+
self.conv = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
|
|
3484
|
+
|
|
3485
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
3486
|
+
"""Downsamples the input tensor.
|
|
3487
|
+
|
|
3488
|
+
Parameters
|
|
3489
|
+
----------
|
|
3490
|
+
`x` : torch.Tensor
|
|
3491
|
+
Input tensor, shape (batch_size, in_channels, height, width).
|
|
3492
|
+
|
|
3493
|
+
Returns
|
|
3494
|
+
-------
|
|
3495
|
+
out : torch.Tensor
|
|
3496
|
+
Downsampled tensor, shape (batch_size, out_channels, height//2, width//2).
|
|
3497
|
+
"""
|
|
3498
|
+
return self.conv(x)
|
|
3499
|
+
|
|
3500
|
+
###==================================================================================================================###
|
|
3501
|
+
|
|
3502
|
+
class TrainUpsamplerUnCLIP(nn.Module):
|
|
3503
|
+
"""Trainer for the UnCLIP upsampler model.
|
|
3504
|
+
|
|
3505
|
+
Orchestrates the training of the UnCLIP upsampler model, integrating forward diffusion,
|
|
3506
|
+
noise prediction, and low-resolution image conditioning with optional corruption (Gaussian
|
|
3507
|
+
blur or BSR degradation). Supports mixed precision, gradient accumulation, DDP, and
|
|
3508
|
+
comprehensive training utilities.
|
|
3509
|
+
|
|
3510
|
+
Parameters
|
|
3511
|
+
----------
|
|
3512
|
+
`upsampler_model` : nn.Module
|
|
3513
|
+
The UnCLIP upsampler model (e.g., UpsamplerUnCLIP) to be trained.
|
|
3514
|
+
`train_loader` : torch.utils.data.DataLoader
|
|
3515
|
+
DataLoader for training data, providing low- and high-resolution image pairs.
|
|
3516
|
+
`optimizer` : torch.optim.Optimizer
|
|
3517
|
+
Optimizer for training the upsampler model.
|
|
3518
|
+
`objective` : Callable
|
|
3519
|
+
Loss function to compute the difference between predicted and target noise.
|
|
3520
|
+
`val_loader` : torch.utils.data.DataLoader, optional
|
|
3521
|
+
DataLoader for validation data, default None.
|
|
3522
|
+
`max_epochs` : int, optional
|
|
3523
|
+
Maximum number of training epochs (default: 1000).
|
|
3524
|
+
`device` : Union[str, torch.device], optional
|
|
3525
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
3526
|
+
`store_path` : str, optional
|
|
3527
|
+
Directory to save model checkpoints (default: "unclip_upsampler").
|
|
3528
|
+
`patience` : int, optional
|
|
3529
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
3530
|
+
`warmup_epochs` : int, optional
|
|
3531
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
3532
|
+
`val_frequency` : int, optional
|
|
3533
|
+
Frequency (in epochs) for validation (default: 10).
|
|
3534
|
+
`use_ddp` : bool, optional
|
|
3535
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
3536
|
+
`grad_accumulation_steps` : int, optional
|
|
3537
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
3538
|
+
`log_frequency` : int, optional
|
|
3539
|
+
Frequency (in epochs) for printing progress (default: 1).
|
|
3540
|
+
`use_compilation` : bool, optional
|
|
3541
|
+
Whether to compile the model using torch.compile (default: False).
|
|
3542
|
+
`image_output_range` : Tuple[float, float], optional
|
|
3543
|
+
Range for clamping output images (default: (-1.0, 1.0)).
|
|
3544
|
+
`normalize_image_outputs` : bool, optional
|
|
3545
|
+
Whether to normalize inputs/outputs (default: True).
|
|
3546
|
+
`use_autocast` : bool, optional
|
|
3547
|
+
Whether to use automatic mixed precision training (default: True).
|
|
3548
|
+
"""
|
|
3549
|
+
|
|
3550
|
+
def __init__(
|
|
3551
|
+
self,
|
|
3552
|
+
upsampler_model: nn.Module,
|
|
3553
|
+
train_loader: torch.utils.data.DataLoader,
|
|
3554
|
+
optimizer: torch.optim.Optimizer,
|
|
3555
|
+
objective: Callable,
|
|
3556
|
+
val_loader: Optional[torch.utils.data.DataLoader] = None,
|
|
3557
|
+
max_epochs: int = 1000,
|
|
3558
|
+
device: Optional[Union[str, torch.device]] = None,
|
|
3559
|
+
store_path: str = "unclip_upsampler",
|
|
3560
|
+
patience: int = 100,
|
|
3561
|
+
warmup_epochs: int = 100,
|
|
3562
|
+
val_frequency: int = 10,
|
|
3563
|
+
use_ddp: bool = False,
|
|
3564
|
+
grad_accumulation_steps: int = 1,
|
|
3565
|
+
log_frequency: int = 1,
|
|
3566
|
+
use_compilation: bool = False,
|
|
3567
|
+
image_output_range: Tuple[float, float] = (-1.0, 1.0),
|
|
3568
|
+
normalize_image_outputs: bool = True,
|
|
3569
|
+
use_autocast: bool = True
|
|
3570
|
+
) -> None:
|
|
3571
|
+
super().__init__()
|
|
3572
|
+
|
|
3573
|
+
# training configuration
|
|
3574
|
+
self.use_ddp = use_ddp
|
|
3575
|
+
self.grad_accumulation_steps = grad_accumulation_steps
|
|
3576
|
+
self.use_compilation = use_compilation
|
|
3577
|
+
self.use_autocast = use_autocast # Store autocast flag
|
|
3578
|
+
|
|
3579
|
+
# device initialization
|
|
3580
|
+
if device is None:
|
|
3581
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
3582
|
+
elif isinstance(device, str):
|
|
3583
|
+
self.device = torch.device(device)
|
|
3584
|
+
else:
|
|
3585
|
+
self.device = device
|
|
3586
|
+
|
|
3587
|
+
# setup distributed training
|
|
3588
|
+
if self.use_ddp:
|
|
3589
|
+
self._setup_ddp()
|
|
3590
|
+
else:
|
|
3591
|
+
self._setup_single_gpu()
|
|
3592
|
+
|
|
3593
|
+
# compile and wrap models
|
|
3594
|
+
self._compile_models()
|
|
3595
|
+
self._wrap_models_for_ddp()
|
|
3596
|
+
|
|
3597
|
+
# core model
|
|
3598
|
+
self.upsampler_model = upsampler_model.to(self.device)
|
|
3599
|
+
self.num_timesteps = self.upsampler_model.forward_diffusion.variance_scheduler.num_steps
|
|
3600
|
+
|
|
3601
|
+
# training components
|
|
3602
|
+
self.optimizer = optimizer
|
|
3603
|
+
self.objective = objective
|
|
3604
|
+
self.train_loader = train_loader
|
|
3605
|
+
self.val_loader = val_loader
|
|
3606
|
+
|
|
3607
|
+
# training parameters
|
|
3608
|
+
self.max_epochs = max_epochs
|
|
3609
|
+
self.patience = patience
|
|
3610
|
+
self.val_frequency = val_frequency
|
|
3611
|
+
self.log_frequency = log_frequency
|
|
3612
|
+
self.image_output_range = image_output_range
|
|
3613
|
+
self.normalize_image_outputs = normalize_image_outputs
|
|
3614
|
+
|
|
3615
|
+
# checkpoint management
|
|
3616
|
+
self.store_path = store_path
|
|
3617
|
+
|
|
3618
|
+
# learning rate scheduling
|
|
3619
|
+
self.scheduler = ReduceLROnPlateau(
|
|
3620
|
+
self.optimizer,
|
|
3621
|
+
patience=self.patience,
|
|
3622
|
+
factor=0.5
|
|
3623
|
+
)
|
|
3624
|
+
self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
3625
|
+
|
|
3626
|
+
def forward(self) -> Tuple[List[float], float]:
|
|
3627
|
+
"""Trains the UnCLIP upsampler model to predict noise for denoising.
|
|
3628
|
+
|
|
3629
|
+
Executes the training loop, optimizing the upsampler model using low- and high-resolution
|
|
3630
|
+
image pairs, mixed precision, gradient clipping, and learning rate scheduling. Supports
|
|
3631
|
+
validation, early stopping, and checkpointing.
|
|
3632
|
+
|
|
3633
|
+
Returns
|
|
3634
|
+
-------
|
|
3635
|
+
train_losses : List[float]
|
|
3636
|
+
List of mean training losses per epoch.
|
|
3637
|
+
best_val_loss : float
|
|
3638
|
+
Best validation or training loss achieved.
|
|
3639
|
+
"""
|
|
3640
|
+
# set models to training mode
|
|
3641
|
+
self.upsampler_model.train()
|
|
3642
|
+
if self.upsampler_model.forward_diffusion.variance_scheduler.trainable_beta:
|
|
3643
|
+
self.upsampler_model.forward_diffusion.variance_scheduler.train()
|
|
3644
|
+
else:
|
|
3645
|
+
self.upsampler_model.forward_diffusion.variance_scheduler.eval()
|
|
3646
|
+
|
|
3647
|
+
# initialize training components
|
|
3648
|
+
scaler = torch.GradScaler() if self.use_autocast else None
|
|
3649
|
+
train_losses = []
|
|
3650
|
+
best_val_loss = float("inf")
|
|
3651
|
+
wait = 0
|
|
3652
|
+
|
|
3653
|
+
# main training loop
|
|
3654
|
+
for epoch in range(self.max_epochs):
|
|
3655
|
+
if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
|
|
3656
|
+
self.train_loader.sampler.set_epoch(epoch)
|
|
3657
|
+
|
|
3658
|
+
train_losses_epoch = []
|
|
3659
|
+
|
|
3660
|
+
# training step loop with gradient accumulation
|
|
3661
|
+
for step, (low_res_images, high_res_images) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
|
|
3662
|
+
low_res_images = low_res_images.to(self.device, non_blocking=True)
|
|
3663
|
+
high_res_images = high_res_images.to(self.device, non_blocking=True)
|
|
3664
|
+
|
|
3665
|
+
# forward pass with optional autocast
|
|
3666
|
+
if self.use_autocast:
|
|
3667
|
+
with torch.autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
|
|
3668
|
+
batch_size = high_res_images.shape[0]
|
|
3669
|
+
timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
|
|
3670
|
+
noise = torch.randn_like(high_res_images)
|
|
3671
|
+
# force FP32 for forward_diffusion to avoid NaN in variance scheduling
|
|
3672
|
+
with torch.autocast(device_type='cuda', enabled=False):
|
|
3673
|
+
high_res_images_noisy = self.upsampler_model.forward_diffusion(high_res_images, noise, timesteps)
|
|
3674
|
+
corruption_type = "gaussian_blur" if self.upsampler_model.low_res_size == 64 else "bsr_degradation"
|
|
3675
|
+
low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
|
|
3676
|
+
predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
|
|
3677
|
+
loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps
|
|
3678
|
+
else:
|
|
3679
|
+
batch_size = high_res_images.shape[0]
|
|
3680
|
+
timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
|
|
3681
|
+
noise = torch.randn_like(high_res_images)
|
|
3682
|
+
high_res_images_noisy = self.upsampler_model.forward_diffusion(high_res_images, noise, timesteps)
|
|
3683
|
+
corruption_type = "gaussian_blur" if self.upsampler_model.low_res_size == 64 else "bsr_degradation"
|
|
3684
|
+
low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
|
|
3685
|
+
predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
|
|
3686
|
+
loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps
|
|
3687
|
+
|
|
3688
|
+
# backward pass
|
|
3689
|
+
if self.use_autocast:
|
|
3690
|
+
scaler.scale(loss).backward()
|
|
3691
|
+
else:
|
|
3692
|
+
loss.backward()
|
|
3693
|
+
|
|
3694
|
+
if (step + 1) % self.grad_accumulation_steps == 0:
|
|
3695
|
+
# clip gradients
|
|
3696
|
+
if self.use_autocast:
|
|
3697
|
+
scaler.unscale_(self.optimizer)
|
|
3698
|
+
torch.nn.utils.clip_grad_norm_(self.upsampler_model.parameters(), max_norm=1.0)
|
|
3699
|
+
torch.nn.utils.clip_grad_norm_(self.upsampler_model.forward_diffusion.parameters(), max_norm=1.0)
|
|
3700
|
+
|
|
3701
|
+
# optimizer step
|
|
3702
|
+
if self.use_autocast:
|
|
3703
|
+
scaler.step(self.optimizer)
|
|
3704
|
+
scaler.update()
|
|
3705
|
+
else:
|
|
3706
|
+
self.optimizer.step()
|
|
3707
|
+
self.optimizer.zero_grad()
|
|
3708
|
+
torch.cuda.empty_cache() # clear memory after optimizer step
|
|
3709
|
+
|
|
3710
|
+
train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)
|
|
3711
|
+
|
|
3712
|
+
self.warmup_lr_scheduler.step()
|
|
3713
|
+
|
|
3714
|
+
mean_train_loss = self._compute_mean_loss(train_losses_epoch)
|
|
3715
|
+
train_losses.append(mean_train_loss)
|
|
3716
|
+
|
|
3717
|
+
if self.master_process and (epoch + 1) % self.log_frequency == 0:
|
|
3718
|
+
current_lr = self.optimizer.param_groups[0]['lr']
|
|
3719
|
+
print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")
|
|
3720
|
+
|
|
3721
|
+
current_loss = mean_train_loss
|
|
3722
|
+
|
|
3723
|
+
if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
|
|
3724
|
+
val_loss = self.validate()
|
|
3725
|
+
if self.master_process:
|
|
3726
|
+
print(f" | Val Loss: {val_loss:.4f}")
|
|
3727
|
+
print()
|
|
3728
|
+
current_loss = val_loss
|
|
3729
|
+
|
|
3730
|
+
self.scheduler.step(current_loss)
|
|
3731
|
+
|
|
3732
|
+
if self.master_process:
|
|
3733
|
+
if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
|
|
3734
|
+
best_val_loss = current_loss
|
|
3735
|
+
wait = 0
|
|
3736
|
+
self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
|
|
3737
|
+
else:
|
|
3738
|
+
wait += 1
|
|
3739
|
+
if wait >= self.patience:
|
|
3740
|
+
print("Early stopping triggered")
|
|
3741
|
+
self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
|
|
3742
|
+
break
|
|
3743
|
+
|
|
3744
|
+
if self.use_ddp:
|
|
3745
|
+
destroy_process_group()
|
|
3746
|
+
|
|
3747
|
+
return train_losses, best_val_loss
|
|
3748
|
+
|
|
3749
|
+
def _compute_mean_loss(self, losses: List[float]) -> float:
|
|
3750
|
+
"""Computes mean loss with DDP synchronization if needed.
|
|
3751
|
+
|
|
3752
|
+
Calculates the mean of the provided losses and synchronizes the result across
|
|
3753
|
+
processes in DDP mode.
|
|
3754
|
+
|
|
3755
|
+
Parameters
|
|
3756
|
+
----------
|
|
3757
|
+
`losses` : List[float]
|
|
3758
|
+
List of loss values for the current epoch.
|
|
3759
|
+
|
|
3760
|
+
Returns
|
|
3761
|
+
-------
|
|
3762
|
+
mean_loss : float
|
|
3763
|
+
Mean loss value, synchronized if using DDP.
|
|
3764
|
+
"""
|
|
3765
|
+
if not losses:
|
|
3766
|
+
return 0.0
|
|
3767
|
+
mean_loss = sum(losses) / len(losses)
|
|
3768
|
+
if self.use_ddp:
|
|
3769
|
+
# synchronize loss across all processes
|
|
3770
|
+
loss_tensor = torch.tensor(mean_loss, device=self.device)
|
|
3771
|
+
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
|
|
3772
|
+
mean_loss = (loss_tensor / self.ddp_world_size).item()
|
|
3773
|
+
|
|
3774
|
+
return mean_loss
|
|
3775
|
+
|
|
3776
|
+
def _setup_ddp(self) -> None:
|
|
3777
|
+
"""Sets up Distributed Data Parallel training configuration.
|
|
3778
|
+
|
|
3779
|
+
Initializes the process group, sets up rank information, and configures the CUDA
|
|
3780
|
+
device for the current process in DDP mode.
|
|
3781
|
+
"""
|
|
3782
|
+
required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
|
|
3783
|
+
for var in required_env_vars:
|
|
3784
|
+
if var not in os.environ:
|
|
3785
|
+
raise ValueError(f"DDP enabled but {var} environment variable not set")
|
|
3786
|
+
|
|
3787
|
+
if not torch.cuda.is_available():
|
|
3788
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
3789
|
+
|
|
3790
|
+
if not torch.distributed.is_initialized():
|
|
3791
|
+
init_process_group(backend="nccl")
|
|
3792
|
+
|
|
3793
|
+
self.ddp_rank = int(os.environ["RANK"])
|
|
3794
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
|
3795
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"])
|
|
3796
|
+
|
|
3797
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
3798
|
+
torch.cuda.set_device(self.device)
|
|
3799
|
+
|
|
3800
|
+
self.master_process = self.ddp_rank == 0
|
|
3801
|
+
|
|
3802
|
+
if self.master_process:
|
|
3803
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
3804
|
+
|
|
3805
|
+
def _setup_single_gpu(self) -> None:
|
|
3806
|
+
"""Sets up single GPU or CPU training configuration.
|
|
3807
|
+
|
|
3808
|
+
Configures the training setup for single-device operation, setting rank and process
|
|
3809
|
+
information for non-DDP training.
|
|
3810
|
+
"""
|
|
3811
|
+
self.ddp_rank = 0
|
|
3812
|
+
self.ddp_local_rank = 0
|
|
3813
|
+
self.ddp_world_size = 1
|
|
3814
|
+
self.master_process = True
|
|
3815
|
+
|
|
3816
|
+
@staticmethod
|
|
3817
|
+
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
|
|
3818
|
+
"""Creates a learning rate scheduler for warmup.
|
|
3819
|
+
|
|
3820
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
3821
|
+
optimizer's initial value over the specified warmup epochs.
|
|
3822
|
+
|
|
3823
|
+
Parameters
|
|
3824
|
+
----------
|
|
3825
|
+
`optimizer` : torch.optim.Optimizer
|
|
3826
|
+
Optimizer to apply the scheduler to.
|
|
3827
|
+
`warmup_epochs` : int
|
|
3828
|
+
Number of epochs for the warmup phase.
|
|
3829
|
+
|
|
3830
|
+
Returns
|
|
3831
|
+
-------
|
|
3832
|
+
lr_scheduler : torch.optim.lr_scheduler.LambdaLR
|
|
3833
|
+
Learning rate scheduler for warmup.
|
|
3834
|
+
"""
|
|
3835
|
+
def lr_lambda(epoch):
|
|
3836
|
+
return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
|
|
3837
|
+
|
|
3838
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
3839
|
+
|
|
3840
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
3841
|
+
"""Wraps models with DistributedDataParallel for multi-GPU training.
|
|
3842
|
+
|
|
3843
|
+
Configures the upsampler model for DDP training by wrapping it with DistributedDataParallel.
|
|
3844
|
+
"""
|
|
3845
|
+
if self.use_ddp:
|
|
3846
|
+
self.upsampler_model = self.upsampler_model.to(self.ddp_local_rank)
|
|
3847
|
+
self.upsampler_model = DDP(
|
|
3848
|
+
self.upsampler_model,
|
|
3849
|
+
device_ids=[self.ddp_local_rank],
|
|
3850
|
+
find_unused_parameters=True
|
|
3851
|
+
)
|
|
3852
|
+
|
|
3853
|
+
def _compile_models(self) -> None:
|
|
3854
|
+
"""Compiles models for optimization if supported.
|
|
3855
|
+
|
|
3856
|
+
Attempts to compile the upsampler model using torch.compile for optimization,
|
|
3857
|
+
falling back to uncompiled execution if compilation fails.
|
|
3858
|
+
"""
|
|
3859
|
+
if self.use_compilation:
|
|
3860
|
+
try:
|
|
3861
|
+
self.upsampler_model = self.upsampler_model.to(self.device)
|
|
3862
|
+
self.upsampler_model = torch.compile(self.upsampler_model, mode="reduce-overhead")
|
|
3863
|
+
|
|
3864
|
+
if self.master_process:
|
|
3865
|
+
print("Models compiled successfully")
|
|
3866
|
+
except Exception as e:
|
|
3867
|
+
if self.master_process:
|
|
3868
|
+
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
3869
|
+
|
|
3870
|
+
def corrupt_conditioning_image(self, x_low: torch.Tensor, corruption_type: str = "gaussian_blur") -> torch.Tensor:
|
|
3871
|
+
"""Corrupts the low-resolution conditioning image for robustness.
|
|
3872
|
+
|
|
3873
|
+
Applies Gaussian blur or BSR degradation to the low-resolution image to simulate
|
|
3874
|
+
real-world degradation, as specified in the UnCLIP paper.
|
|
3875
|
+
|
|
3876
|
+
Parameters
|
|
3877
|
+
----------
|
|
3878
|
+
`x_low` : torch.Tensor
|
|
3879
|
+
Low-resolution input image, shape (batch_size, channels, low_res_size, low_res_size).
|
|
3880
|
+
`corruption_type` : str, optional
|
|
3881
|
+
Type of corruption to apply: "gaussian_blur" or "bsr_degradation" (default: "gaussian_blur").
|
|
3882
|
+
|
|
3883
|
+
Returns
|
|
3884
|
+
-------
|
|
3885
|
+
x_degraded : torch.Tensor
|
|
3886
|
+
Corrupted low-resolution image, same shape as input.
|
|
3887
|
+
"""
|
|
3888
|
+
if corruption_type == "gaussian_blur":
|
|
3889
|
+
# apply Gaussian blur
|
|
3890
|
+
kernel_size = random.choice([3, 5, 7])
|
|
3891
|
+
sigma = random.uniform(0.5, 2.0)
|
|
3892
|
+
return self._gaussian_blur(x_low, kernel_size, sigma)
|
|
3893
|
+
elif corruption_type == "bsr_degradation":
|
|
3894
|
+
# more diverse BSR degradation for second upsampler
|
|
3895
|
+
return self._bsr_degradation(x_low)
|
|
3896
|
+
else:
|
|
3897
|
+
return x_low
|
|
3898
|
+
|
|
3899
|
+
def _gaussian_blur(self, x: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
|
|
3900
|
+
"""Applies Gaussian blur to the input image.
|
|
3901
|
+
|
|
3902
|
+
Parameters
|
|
3903
|
+
----------
|
|
3904
|
+
`x` : torch.Tensor
|
|
3905
|
+
Input image tensor, shape (batch_size, channels, height, width).
|
|
3906
|
+
`kernel_size` : int
|
|
3907
|
+
Size of the Gaussian kernel.
|
|
3908
|
+
`sigma` : float
|
|
3909
|
+
Standard deviation of the Gaussian distribution.
|
|
3910
|
+
|
|
3911
|
+
Returns
|
|
3912
|
+
-------
|
|
3913
|
+
x_blurred : torch.Tensor
|
|
3914
|
+
Blurred image tensor, same shape as input.
|
|
3915
|
+
"""
|
|
3916
|
+
# create Gaussian kernel
|
|
3917
|
+
kernel = self._get_gaussian_kernel(kernel_size, sigma).to(x.device)
|
|
3918
|
+
kernel = kernel.expand(x.shape[1], 1, kernel_size, kernel_size)
|
|
3919
|
+
padding = kernel_size // 2
|
|
3920
|
+
return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
|
|
3921
|
+
|
|
3922
|
+
def _get_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
|
|
3923
|
+
"""Generates a 2D Gaussian kernel.
|
|
3924
|
+
|
|
3925
|
+
Parameters
|
|
3926
|
+
----------
|
|
3927
|
+
`kernel_size` : int
|
|
3928
|
+
Size of the Gaussian kernel.
|
|
3929
|
+
`sigma` : float
|
|
3930
|
+
Standard deviation of the Gaussian distribution.
|
|
3931
|
+
|
|
3932
|
+
Returns
|
|
3933
|
+
-------
|
|
3934
|
+
kernel : torch.Tensor
|
|
3935
|
+
2D Gaussian kernel, shape (kernel_size, kernel_size).
|
|
3936
|
+
"""
|
|
3937
|
+
coords = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
|
|
3938
|
+
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
|
|
3939
|
+
g = g / g.sum()
|
|
3940
|
+
return g[:, None] * g[None, :]
|
|
3941
|
+
|
|
3942
|
+
def _bsr_degradation(self, x: torch.Tensor) -> torch.Tensor:
|
|
3943
|
+
"""Applies BSR degradation to the input image.
|
|
3944
|
+
|
|
3945
|
+
Simulates degradation with noise and Gaussian blur, as used in the UnCLIP paper
|
|
3946
|
+
for the second upsampler.
|
|
3947
|
+
|
|
3948
|
+
Parameters
|
|
3949
|
+
----------
|
|
3950
|
+
`x` : torch.Tensor
|
|
3951
|
+
Input image tensor, shape (batch_size, channels, height, width).
|
|
3952
|
+
|
|
3953
|
+
Returns
|
|
3954
|
+
-------
|
|
3955
|
+
x_degraded : torch.Tensor
|
|
3956
|
+
Degraded image tensor, same shape as input, clamped to [-1, 1].
|
|
3957
|
+
"""
|
|
3958
|
+
# add noise
|
|
3959
|
+
noise_level = random.uniform(0.0, 0.1)
|
|
3960
|
+
noise = torch.randn_like(x) * noise_level
|
|
3961
|
+
|
|
3962
|
+
# apply blur
|
|
3963
|
+
kernel_size = random.choice([3, 5, 7])
|
|
3964
|
+
sigma = random.uniform(0.5, 3.0)
|
|
3965
|
+
x_degraded = self._gaussian_blur(x + noise, kernel_size, sigma)
|
|
3966
|
+
|
|
3967
|
+
return torch.clamp(x_degraded, -1.0, 1.0)
|
|
3968
|
+
|
|
3969
|
+
def validate(self) -> float:
    """Validates the UnCLIP upsampler model.

    Computes the validation loss by applying forward diffusion to high-resolution images,
    predicting noise with the upsampler model conditioned on corrupted low-resolution images,
    and comparing predicted noise to ground truth. Afterwards the model is returned to
    training mode, with the variance scheduler kept in eval mode unless its betas are
    trainable.

    Returns
    -------
    val_loss : float
        Mean per-batch validation loss. Under DDP this is the mean over the batches
        of all ranks. NaN if no validation batch was seen.
    """
    # The DDP wrapper does not expose the inner model's attributes
    # (forward_diffusion, low_res_size, ...), so go through the bare module.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    # set models to eval mode for evaluation
    self.upsampler_model.eval()
    base_model.forward_diffusion.eval()

    # the corruption applied to the conditioning image depends only on the model config
    corruption_type = "gaussian_blur" if base_model.low_res_size == 64 else "bsr_degradation"

    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for low_res_images, high_res_images in self.val_loader:
            low_res_images = low_res_images.to(self.device, non_blocking=True)
            high_res_images = high_res_images.to(self.device, non_blocking=True)
            batch_size = high_res_images.shape[0]
            timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
            noise = torch.randn_like(high_res_images)
            high_res_images_noisy = base_model.forward_diffusion(high_res_images, noise, timesteps)
            low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
            predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
            # compute loss
            loss_sum += self.objective(predicted_noise, noise).item()
            num_batches += 1

    if self.use_ddp:
        # ReduceOp.AVG is only supported by the NCCL backend; summing the
        # (loss_sum, num_batches) pair works on every backend and weights ranks by batch count
        totals = torch.tensor([loss_sum, float(num_batches)], device=self.device)
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        loss_sum = totals[0].item()
        num_batches = int(round(totals[1].item()))

    # compute average loss (NaN when there was nothing to validate on)
    val_loss = loss_sum / num_batches if num_batches > 0 else float('nan')

    # return to training mode
    self.upsampler_model.train()
    variance_scheduler = base_model.forward_diffusion.variance_scheduler
    if not variance_scheduler.trainable_beta:
        variance_scheduler.eval()

    return val_loss
|
|
4016
|
+
|
|
4017
|
+
def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
    """Saves model checkpoint.

    Saves the state of the upsampler model, its variance scheduler, optimizer, and
    schedulers, with options for best model and epoch-specific checkpoints. Only the
    master process writes; other ranks return immediately. Model weights are stored
    without the DDP ``module.`` key prefix.

    Parameters
    ----------
    `epoch` : int
        Current epoch number.
    `loss` : float
        Current loss value.
    `is_best` : bool, optional
        Whether to save as the best model checkpoint (default: False).
    `suffix` : str, optional
        Suffix to add to checkpoint filename, default "".
    """
    if not self.master_process:
        return

    # The DDP wrapper does not expose the inner model's attributes
    # (model_channels, forward_diffusion, ...), so read everything from the bare module.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    checkpoint = {
        'epoch': epoch,
        'loss': loss,
        # core model
        'upsampler_model_state_dict': base_model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        # training configuration
        'model_channels': base_model.model_channels,
        'num_res_blocks': base_model.num_res_blocks,
        'normalize': self.normalize_image_outputs,
        'output_range': self.image_output_range,
        # variance scheduler (submodule of forward_diffusion)
        'variance_scheduler_state_dict': base_model.forward_diffusion.variance_scheduler.state_dict(),
        # learning-rate schedulers
        'scheduler_state_dict': self.scheduler.state_dict(),
        'warmup_scheduler_state_dict': self.warmup_lr_scheduler.state_dict(),
    }

    if is_best:
        filename = f"unclip_upsampler_best{suffix}.pth"
    else:
        filename = f"unclip_upsampler_epoch_{epoch}{suffix}.pth"

    os.makedirs(self.store_path, exist_ok=True)
    filepath = os.path.join(self.store_path, filename)
    torch.save(checkpoint, filepath)

    if is_best:
        print(f"Best model saved: {filepath}")
|
|
4069
|
+
|
|
4070
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads model checkpoint.

    Restores the state of the upsampler model, its variance scheduler, optimizer, and
    schedulers from a saved checkpoint, handling DDP compatibility: model weights saved
    with or without the DDP ``module.`` key prefix are both accepted. A component that
    fails to load is reported via ``warnings.warn`` and the rest still loads.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if not recorded).
    loss : float
        The loss at the checkpoint (``inf`` if not recorded).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from err

    # The DDP wrapper hides the inner model's attributes (forward_diffusion,
    # model_channels, ...). The bare module shares its parameters with the wrapper,
    # so loading into it also updates the DDP model.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    def _strip_ddp_prefix(state_dict: dict) -> dict:
        """Removes a single leading 'module.' from each key (inner occurrences are kept)."""
        prefix = 'module.'
        return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

    def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
        """Loads `state_dict` into the unwrapped `model`, warning instead of raising on failure."""
        try:
            model.load_state_dict(_strip_ddp_prefix(state_dict))
            if self.master_process:
                print(f"✓ Loaded {model_name}")
        except Exception as e:
            warnings.warn(f"Failed to load {model_name}: {e}")

    # load core upsampler model
    if 'upsampler_model_state_dict' in checkpoint:
        _load_model_state_dict(base_model, checkpoint['upsampler_model_state_dict'], 'upsampler_model')

    # load variance scheduler (submodule of forward_diffusion); the older
    # 'hyper_params_state_dict' key is accepted as a fallback
    if 'variance_scheduler_state_dict' in checkpoint or 'hyper_params_state_dict' in checkpoint:
        state_dict = checkpoint.get('variance_scheduler_state_dict', checkpoint.get('hyper_params_state_dict'))
        try:
            _load_model_state_dict(base_model.forward_diffusion.variance_scheduler, state_dict,
                                   'variance_scheduler')
        except Exception as e:
            warnings.warn(f"Failed to load variance scheduler: {e}")

    # load optimizer
    if 'optimizer_state_dict' in checkpoint:
        try:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if self.master_process:
                print("✓ Loaded optimizer")
        except Exception as e:
            warnings.warn(f"Failed to load optimizer state: {e}")

    # load schedulers
    if 'scheduler_state_dict' in checkpoint:
        try:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded main scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load scheduler state: {e}")

    if 'warmup_scheduler_state_dict' in checkpoint:
        try:
            self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded warmup scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load warmup scheduler state: {e}")

    # verify configuration compatibility (read from the bare module: DDP hides these attributes)
    if 'model_channels' in checkpoint and checkpoint['model_channels'] != base_model.model_channels:
        warnings.warn(
            f"Model channels mismatch: checkpoint={checkpoint['model_channels']}, current={base_model.model_channels}")

    if 'num_res_blocks' in checkpoint and checkpoint['num_res_blocks'] != base_model.num_res_blocks:
        warnings.warn(
            f"Num res blocks mismatch: checkpoint={checkpoint['num_res_blocks']}, current={base_model.num_res_blocks}")

    if 'normalize' in checkpoint and checkpoint['normalize'] != self.normalize_image_outputs:
        warnings.warn(
            f"Normalize setting mismatch: checkpoint={checkpoint['normalize']}, current={self.normalize_image_outputs}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

    return epoch, loss
|