TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/unclip.py ADDED
@@ -0,0 +1,4170 @@
1
+ """
2
+ **UnCLIP Diffusion Model**
3
+
4
+ This module provides a comprehensive implementation of the UnCLIP diffusion model,
5
+ as described in Ramesh et al. (2022, "Hierarchical Text-Conditional Image Generation with CLIP Latents").
6
+ It integrates CLIP embeddings with diffusion processes for high-quality image generation conditioned on text prompts or image embeddings.
7
+ The module supports training, sampling, and upsampling processes, leveraging components from CLIP, GLIDE, and DDIM,
8
+ with classifier-free guidance and text dropout for robust generation.
9
+
10
+ **Components**
11
+
12
+ - **VarianceSchedulerUnCLIP**: Manages noise schedules with support for linear, sigmoid, quadratic, constant, inverse_time,
13
+ and cosine beta schedules, including subsampled (tau) schedules for efficient sampling.
14
+ - **ForwardUnCLIP**: Forward diffusion process to add noise to image or latent embeddings.
15
+ - **ReverseUnCLIP**: Reverse diffusion process for denoising, supporting noise or clean image predictions with subsampled steps.
16
+ - **CLIPEncoder**: Encodes images or text into embeddings using a pre-trained CLIP model.
17
+ - **UnClipDecoder**: Generates low-resolution images (64x64) from CLIP embeddings, incorporating GLIDE text encoding and classifier-free guidance.
18
+ - **UnCLIPTransformerPrior**: Transformer-based prior to predict clean image embeddings from noisy embeddings and text conditions.
19
+ - **CLIPContextProjection**: Projects CLIP image embeddings into context tokens for the decoder.
20
+ - **CLIPEmbeddingProjection**: Reduces and reconstructs embedding dimensionality for efficient processing.
21
+ - **TrainUnClipDecoder**: Orchestrates training of the decoder with mixed precision, gradient accumulation, and DDP support.
22
+ - **SampleUnCLIP**: Generates images from text prompts or noise, scaling from 64x64 to 256x256 or 1024x1024 with upsamplers.
23
+ - **UpsamplerUnCLIP**: U-Net-based upsampler for scaling images (64x64 to 256x256 or 256x256 to 1024x1024), conditioned on low-resolution inputs.
24
+ - **TrainUpsamplerUnCLIP**: Trains the upsampler with noise prediction, low-resolution conditioning, and optional image corruption (Gaussian blur or BSR degradation).
25
+
26
+ **Notes**
27
+
28
+ - The model uses a subsampled time step schedule (tau) for faster sampling, controlled by the `tau_num_steps` parameter in VarianceSchedulerUnCLIP.
29
+ - Classifier-free guidance and text dropout enhance generation quality, with tunable parameters `classifier_free_prop` and `drop_caption`.
30
+ - Upsampling stages use corrupted low-resolution inputs (Gaussian blur for 64x64→256x256, BSR degradation for 256x256→1024x1024) to improve robustness.
31
+ - Supports distributed training with DDP, mixed precision via autocast, and learning rate scheduling with warmup and plateau reduction.
32
+
33
+ **References**
34
+
35
+ - Ramesh, Aditya, et al. "Hierarchical Text-Conditional Image Generation with CLIP Latents." arXiv preprint arXiv:2204.06125 (2022).
36
+ - Radford, Alec, et al. "Learning Transferable Visual Models From Natural Language Supervision." arXiv preprint arXiv:2103.00020 (2021).
37
+ - Nichol, Alexander, et al. "GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models." arXiv preprint arXiv:2112.10741 (2021).
38
+ - Song, Jiaming, et al. "Denoising Diffusion Implicit Models." arXiv preprint arXiv:2010.02502 (2020).
39
+
40
+ -------------------------------------------------------------------------------
41
+ """
42
+
43
+ import torch
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
47
+ import torch.distributed as dist
48
+ from torch.nn.parallel import DistributedDataParallel as DDP
49
+ from torch.distributed import init_process_group, destroy_process_group
50
+ import torchvision
51
+ from PIL import Image
52
+ from transformers import BertTokenizer, CLIPProcessor, CLIPModel
53
+ from typing import Optional, List, Tuple, Union, Callable, Any, Self
54
+ from tqdm import tqdm
55
+ import os
56
+ import warnings
57
+ import random
58
+ import math
59
+
60
+
61
+ ###==================================================================================================================###
62
+
63
+
64
class VarianceSchedulerUnCLIP(nn.Module):
    """Manages noise schedule parameters for UnCLIP diffusion models.

    Holds the beta values, the quantities derived from them (alphas, cumulative
    products and their square roots) and a subsampled time step schedule
    (tau schedule). Betas are either fixed (stored as buffers together with the
    derived quantities) or trainable (stored as an unconstrained parameter
    `beta_raw` and mapped back into (beta_start, beta_end) with a sigmoid).

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the reverse process (default: None, stored as 0).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for sampling (default: 100).
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

    Raises
    ------
    ValueError
        If the beta range does not satisfy 0 < start < end < 1, if `num_steps`
        or `tau_num_steps` is not positive, or if `beta_method` is unknown.
    """

    # Keeps the normalised betas strictly inside (0, 1) before taking the logit,
    # so the trainable parameter never starts at +/-inf.
    _LOGIT_EPS = 1e-6

    def __init__(
            self,
            eta: Optional[float] = None,
            num_steps: int = 1000,
            tau_num_steps: int = 100,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear"
    ) -> None:
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Schedules touch the range endpoints (e.g. the first/last value of the
            # linear schedule, every value of the constant one), where an unclamped
            # logit is +/-inf. `eps` clamps the input to [eps, 1 - eps] first.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=self._LOGIT_EPS))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Indices into the full schedule that make up the subsampled (tau) schedule.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,). For trainable schedules this is
            `beta_start + (beta_end - beta_start) * sigmoid(beta_raw)`; otherwise
            it is the stored `betas_buffer`.
        """
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        The result of every method is clamped to `beta_range` before being returned.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

        Returns
        -------
        beta : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            spread = beta.max() - beta.min()
            if spread > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / spread
            else:
                # A single step has nothing to rescale (0 / 0 would give NaN);
                # use the lower bound, as the linear schedule does for one step.
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        elif method == "cosine":
            s = 0.008
            steps = num_steps + 1
            x = torch.linspace(0, num_steps, steps)
            alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            beta = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        else:
            raise ValueError(f"Unknown beta_method: {method}")
        return torch.clamp(beta, min=beta_min, max=beta_max)

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes the subsampled (tau) noise schedule for UnCLIP.

        Selects the entries of the full schedule at `tau_indices`. Trainable
        schedules are recomputed from the current `beta_raw`; fixed schedules
        read the registered buffers.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values for subsampled steps, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values for subsampled steps, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative product of alphas for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            full_schedule = self.compute_schedule()
        else:
            full_schedule = (
                self.betas,
                self.alphas,
                self.alpha_cumprod,
                self.sqrt_alpha_cumprod,
                self.sqrt_one_minus_alpha_cumprod,
            )
        tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = (
            values[self.tau_indices] for values in full_schedule
        )
        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters dynamically from betas.

        Derives alphas, their cumulative product and the two square-root terms
        from the current `betas` on every call.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, returns parameters only for specified time steps.
            If None, returns parameters for all time steps.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,) or (len(time_steps),).
        alphas : torch.Tensor
            1 - betas, shape (num_steps,) or (len(time_steps),).
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas, shape (num_steps,) or (len(time_steps),).
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod, shape (num_steps,) or (len(time_steps),).
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod), shape (num_steps,) or (len(time_steps),).
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)
        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
262
+
263
+ ###==================================================================================================================###
264
+
265
class ForwardUnCLIP(nn.Module):
    """Forward diffusion process for UnCLIP diffusion models.

    Applies Gaussian noise to input data according to the forward diffusion
    process at specified time steps, using the cumulative noise schedule
    parameters read from the variance scheduler. Inputs of any rank with a
    leading batch dimension are supported (e.g. 2D embeddings, 4D images).

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP). This class reads
        its `num_steps`, `trainable_beta`, `sqrt_alpha_cumprod` and
        `sqrt_one_minus_alpha_cumprod` attributes and, for trainable schedules,
        calls its `compute_schedule(time_steps)` method.
    """
    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Computes `sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise`
        with one coefficient pair per batch element.

        Parameters
        ----------
        `x0` : torch.Tensor
            Input data tensor with the batch dimension first, e.g.
            (batch_size, embedding_dim) or (batch_size, channels, height, width).
        `noise` : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        `time_steps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,),
            where each value is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt : torch.Tensor
            Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` lies outside the valid range.
        """
        scheduler = self.variance_scheduler
        if not torch.all((time_steps >= 0) & (time_steps < scheduler.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.variance_scheduler.num_steps - 1}")

        if scheduler.trainable_beta:
            _, _, _, signal_scale, noise_scale = scheduler.compute_schedule(time_steps)
        else:
            signal_table = scheduler.sqrt_alpha_cumprod
            noise_table = scheduler.sqrt_one_minus_alpha_cumprod
            # Index on the table's own device so the lookup does not depend on
            # where the caller created `time_steps`.
            signal_scale = signal_table[time_steps.to(signal_table.device)]
            noise_scale = noise_table[time_steps.to(noise_table.device)]

        # One coefficient per sample, broadcast over every non-batch dimension.
        # This yields (-1, 1) for 2D and (-1, 1, 1, 1) for 4D inputs and also
        # covers other ranks, which a fixed 4D reshape would mis-broadcast.
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        signal_scale = signal_scale.to(x0.device).view(broadcast_shape)
        noise_scale = noise_scale.to(x0.device).view(broadcast_shape)

        return signal_scale * x0 + noise_scale * noise
330
+
331
+ ###==================================================================================================================###
332
+
333
class ReverseUnCLIP(nn.Module):
    """Reverse diffusion process for UnCLIP diffusion models.

    Denoises a noisy input `xt` using either a predicted noise component or a
    predicted clean sample, stepping along the subsampled (tau) schedule with a
    DDIM-style update. Inputs of any rank with a leading batch dimension are
    supported (e.g. 2D embeddings, 4D images).

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP). This class reads
        its `tau_num_steps` and `eta` attributes and calls `get_tau_schedule()`.
    `prediction_type` : str, default "noise"
        Type of prediction the model makes. Either "noise" (predicts noise like DDIM) or
        "x0" (predicts clean image like UnCLIP prior).

    Raises
    ------
    ValueError
        If `prediction_type` is neither "noise" nor "x0".
    """

    def __init__(self, variance_scheduler: torch.nn.Module, prediction_type: str = "noise"):
        super().__init__()
        self.variance_scheduler = variance_scheduler
        if prediction_type not in ["noise", "x0"]:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type

    @staticmethod
    def _gather(table: torch.Tensor, indices: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Looks up per-sample schedule values and shapes them to broadcast against `like`."""
        broadcast_shape = (-1,) + (1,) * (like.dim() - 1)
        # Index on the table's device, then move the (small) result next to the data.
        return table[indices.to(table.device)].to(like.device).view(broadcast_shape)

    def forward(
            self,
            xt: torch.Tensor,
            model_prediction: torch.Tensor,
            time_steps: torch.Tensor,
            prev_time_steps: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies the reverse diffusion process to the noisy input.

        Produces `xt_prev` at `prev_time_steps` from `xt` at `time_steps`:

            xt_prev = sqrt(a_prev) * x0 + sigma * z + sqrt(1 - a_prev - sigma^2) * eps

        where `a` is the cumulative alpha product on the tau schedule, `eps` is the
        (predicted or implied) noise, `z` is fresh Gaussian noise and

            sigma = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev).

        With `eta == 0` the step is deterministic.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input tensor at time step `t` with the batch dimension first, e.g.
            (batch_size, embedding_dim) or (batch_size, channels, height, width).
        `model_prediction` : torch.Tensor
            Model prediction tensor, same shape as `xt`. Can be either predicted noise
            or predicted clean image depending on `prediction_type`.
        `time_steps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Tensor of previous time step indices (long), shape (batch_size,), where each
            value is in the range [0, variance_scheduler.tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Denoised tensor at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` or `prev_time_steps` is out of range.
        """
        scheduler = self.variance_scheduler
        if not torch.all((time_steps >= 0) & (time_steps < scheduler.tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.variance_scheduler.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < scheduler.tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {self.variance_scheduler.tau_num_steps - 1}")

        _, _, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = scheduler.get_tau_schedule()

        sqrt_a_t = self._gather(tau_sqrt_alpha_cumprod, time_steps, xt)
        sqrt_1m_a_t = self._gather(tau_sqrt_one_minus_alpha_cumprod, time_steps, xt)
        sqrt_a_prev = self._gather(tau_sqrt_alpha_cumprod, prev_time_steps, xt)
        sqrt_1m_a_prev = self._gather(tau_sqrt_one_minus_alpha_cumprod, prev_time_steps, xt)

        eta = scheduler.eta

        if self.prediction_type == "noise":
            # model predicts noise; recover the clean sample it implies
            predicted_noise = model_prediction
            x0 = (xt - sqrt_1m_a_t * predicted_noise) / sqrt_a_t
        else:
            # model predicts the clean sample; recover the noise it implies
            x0 = model_prediction
            predicted_noise = (xt - sqrt_a_t * x0) / sqrt_1m_a_t

        if eta:
            a_t = self._gather(tau_alpha_cumprod, time_steps, xt)
            a_prev = self._gather(tau_alpha_cumprod, prev_time_steps, xt)
            # DDIM variance; clamps guard the divisions and keep the radicand
            # non-negative if a caller passes prev_time_steps >= time_steps.
            variance = ((1 - a_prev) / torch.clamp(1 - a_t, min=1e-8)) * (1 - a_t / torch.clamp(a_prev, min=1e-8))
            noise_coeff = eta * torch.clamp(variance, min=0.0).sqrt()
        else:
            noise_coeff = torch.zeros_like(sqrt_1m_a_prev)

        direction_coeff = torch.clamp(sqrt_1m_a_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        # Fresh noise is drawn even when its coefficient is zero, so the random
        # stream advances the same way regardless of eta.
        xt_prev = sqrt_a_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0

    def set_prediction_type(self, prediction_type: str):
        """Change the prediction type after initialization.

        Parameters
        ----------
        prediction_type : str
            Type of prediction the model makes. Either "noise" or "x0".

        Raises
        ------
        ValueError
            If `prediction_type` is neither "noise" nor "x0".
        """
        if prediction_type not in ["noise", "x0"]:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type
447
+
448
+ ###==================================================================================================================###
449
+
450
class CLIPEncoder(nn.Module):
    """Encodes images or text using a pre-trained CLIP model.

    Loads a CLIP model and processor from the transformers library, providing methods to
    encode images or text into embeddings and compute similarity scores between them.
    The wrapped model is moved to `device` and put in evaluation mode on construction.

    Parameters
    ----------
    `model_name` : str, optional
        Name of the CLIP model to load (default: 'openai/clip-vit-base-patch32').
    `device` : str, optional
        Device to run the model on (default: 'cuda' if available, else 'cpu').
    `use_fast` : bool, optional
        Forwarded to `CLIPProcessor.from_pretrained` as `use_fast` (default: False).

    Raises
    ------
    RuntimeError
        If the model or processor cannot be loaded; the underlying error is
        attached as the exception's cause.
    """
    def __init__(
            self,
            model_name: str = "openai/clip-vit-base-patch32",
            device: Optional[str] = None,
            use_fast: bool = False,
    ) -> None:
        super().__init__()

        self.model_name = model_name
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        try:
            self.model = CLIPModel.from_pretrained(self.model_name)
            self.processor = CLIPProcessor.from_pretrained(self.model_name, use_fast=use_fast)
            self.model = self.model.to(self.device)
        except Exception as e:
            # Chain explicitly so the original loader failure is reported as the cause.
            raise RuntimeError(f"Failed to load CLIP model or processor for {self.model_name}: {e}") from e

        # set model to evaluation mode by default
        self.model.eval()

    def forward(
            self,
            data: Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]],
            data_type: str,
            normalize: bool = True
    ) -> torch.Tensor:
        """Encodes input data (image or text) using the CLIP model.

        Encoding runs under `torch.no_grad()`; the optional L2 normalization is
        applied afterwards.

        Parameters
        ----------
        `data` : Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]]
            Input data to encode:
            - torch.Tensor: Preprocessed image tensor (batch_size, channels, height, width).
            - List[str] or str: Text or list of texts.
            - PIL.Image.Image or List[PIL.Image.Image]: Single or list of PIL images.
        `data_type` : str
            Type of input data ('img' or 'text').
        `normalize` : bool, optional
            Whether to L2-normalize the output embeddings (default: True).

        Returns
        -------
        outputs : torch.Tensor
            Encoded embeddings, shape (batch_size, embedding_dim).

        Raises
        ------
        ValueError
            If `data_type` is not 'img' or 'text', or if `data` has an unsupported type.
        """
        if data_type not in ["img", "text"]:
            raise ValueError(f"Invalid data_type: {data_type}. Must be 'img' or 'text'.")

        encode = self._encode_images if data_type == "img" else self._encode_texts
        with torch.no_grad():
            outputs = encode(data)

        if normalize:
            outputs = F.normalize(outputs, p=2, dim=-1)

        return outputs

    def _encode_images(self, data: Union[torch.Tensor, Image.Image, List[Image.Image]]) -> torch.Tensor:
        """Encodes images into embeddings using the CLIP model.

        Tensors are passed to the model directly as `pixel_values` (a 3D tensor
        gets a batch dimension first); PIL images go through the CLIP processor.

        Parameters
        ----------
        `data` : Union[torch.Tensor, Image.Image, List[Image.Image]]
            Input images as a tensor or PIL image(s).

        Returns
        -------
        image_features : torch.Tensor
            Image embeddings, shape (batch_size, embedding_dim).

        Raises
        ------
        ValueError
            If `data` is not a tensor, a PIL image, or a list.
        """
        if isinstance(data, torch.Tensor):
            if data.dim() == 3:
                data = data.unsqueeze(0)
            inputs = {"pixel_values": data.to(self.device)}
        elif isinstance(data, (Image.Image, list)):
            images = [data] if isinstance(data, Image.Image) else data
            processed = self.processor(images=images, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in processed.items()}
        else:
            raise ValueError(f"Invalid image data type: {type(data)}. Expected torch.Tensor, PIL.Image.Image, or List[PIL.Image.Image].")
        return self.model.get_image_features(**inputs)

    def _encode_texts(self, data: Union[str, List[str], torch.Tensor]) -> torch.Tensor:
        """Encodes texts into embeddings using the CLIP model.

        Strings are tokenized by the CLIP processor (padded and truncated) and
        encoded. A 1D tensor is treated as a single sequence of token ids and
        encoded with an all-ones attention mask. A 2D tensor is returned unchanged
        after being moved to the device.

        Parameters
        ----------
        `data` : Union[str, List[str], torch.Tensor]
            Input texts as strings or tensor.

        Returns
        -------
        text_features : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).

        Raises
        ------
        ValueError
            If `data` is not a string, a list of strings, or a tensor.
        """
        if isinstance(data, torch.Tensor):
            data = data.to(self.device)
            if data.dim() == 2:
                # NOTE(review): 2D tensors bypass the text model entirely; presumably
                # they are precomputed embeddings rather than batched token ids —
                # confirm against callers.
                return data
            if data.dim() == 1:
                data = data.unsqueeze(0)
            attention_mask = torch.ones_like(data)
            return self.model.get_text_features(input_ids=data, attention_mask=attention_mask)

        if isinstance(data, str):
            data = [data]
        elif not (isinstance(data, list) and all(isinstance(t, str) for t in data)):
            raise ValueError(
                f"Invalid text data type: {type(data)}. Expected str, List[str], or torch.Tensor."
            )

        tokenized = self.processor(text=data, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(self.device) for k, v in tokenized.items()}
        return self.model.get_text_features(**inputs)

    def compute_similarity(self, image_features: torch.Tensor, text_features: torch.Tensor) -> torch.Tensor:
        """Computes cosine similarity between image and text embeddings.

        Both inputs are L2-normalized along the last dimension before the
        matrix product, so already-normalized inputs are accepted as well.

        Parameters
        ----------
        `image_features` : torch.Tensor
            Image embeddings, shape (num_images, embedding_dim).
        `text_features` : torch.Tensor
            Text embeddings, shape (num_texts, embedding_dim).

        Returns
        -------
        similarity : torch.Tensor
            Cosine similarity scores, shape (num_images, num_texts).
        """
        image_unit = F.normalize(image_features, p=2, dim=-1)
        text_unit = F.normalize(text_features, p=2, dim=-1)
        return torch.matmul(image_unit, text_unit.T)
616
+
617
+ ###==================================================================================================================###
618
+
619
class UnClipDecoder(nn.Module):
    """Decoder for UnCLIP diffusion models.

    Combines CLIP image embeddings and text embeddings to guide the denoising process,
    using a noise predictor and diffusion processes. Incorporates classifier-free guidance,
    text caption dropout, and projection of CLIP embeddings into context tokens.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input embeddings.
    `noise_predictor` : nn.Module
        Model to predict noise during the denoising process.
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising.
    `glide_text_encoder` : nn.Module, optional
        GLIDE text encoder for processing text prompts, default None.
    `bert_tokenizer` : BertTokenizer, optional
        Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `image_output_range` : Tuple[float, float], optional
        Range for clamping output images (default: (-1.0, 1.0)). Stored on the
        instance; not read by this class.
    `normalize_clip_embeddings` : bool, optional
        Whether to normalize CLIP embeddings (default: True). Stored on the
        instance; not read by this class.
    `classifier_free_prop` : float, optional
        Threshold for classifier-free guidance (default: 0.1, per paper).
    `drop_caption` : float, optional
        Threshold for text caption dropout (default: 0.5, per paper).
    `max_token_length` : int, optional
        Maximum length for tokenized prompts (default: 77).

    Raises
    ------
    ValueError
        If no tokenizer is supplied and the default one cannot be loaded.
    """
    def __init__(
            self,
            clip_embedding_dim: int,
            noise_predictor: nn.Module,
            forward_diffusion: nn.Module,
            reverse_diffusion: nn.Module,
            glide_text_encoder: Optional[torch.nn.Module] = None,  # GLIDE text encoder
            bert_tokenizer: Optional[BertTokenizer] = None,
            device: Optional[Union[str, torch.device]] = None,
            image_output_range: Tuple[float, float] = (-1.0, 1.0),
            normalize_clip_embeddings: bool = True,
            classifier_free_prop: float = 0.1,  # paper specifies 10%
            drop_caption: float = 0.5,  # paper specifies 50%
            max_token_length: int = 77  # max_token_length for tokenization
    ) -> None:
        super().__init__()

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.clip_embedding_dim = clip_embedding_dim

        # core models
        self.noise_predictor = noise_predictor.to(self.device)
        self.forward_diffusion = forward_diffusion.to(self.device)
        self.reverse_diffusion = reverse_diffusion.to(self.device)
        # explicit None check: container modules define __len__, so an empty
        # one would be falsy and silently discarded by a truthiness test
        if glide_text_encoder is not None:
            self.glide_text_encoder = glide_text_encoder.to(self.device)
        else:
            self.glide_text_encoder = None

        # paper: "projecting CLIP embeddings into four extra tokens of context"
        self.clip_decoder_projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            num_tokens=4
        ).to(self.device)
        self.clip_time_projection = nn.Linear(self.clip_embedding_dim, self.clip_embedding_dim).to(self.device)

        # training parameters
        self.image_output_range = image_output_range
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.classifier_free_prop = classifier_free_prop
        self.drop_caption = drop_caption
        self.max_token_length = max_token_length

        # initialize tokenizer
        if bert_tokenizer is None:
            try:
                self.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            except Exception as e:
                raise ValueError(
                    f"Failed to load default tokenizer: {e}. Please provide a tokenizer."
                ) from e
        else:
            # keep the caller-supplied tokenizer; without this assignment
            # _encode_text_with_glide fails with AttributeError
            self.bert_tokenizer = bert_tokenizer

    def forward(
            self,
            image_embeddings: torch.Tensor,
            text_embeddings: torch.Tensor,
            images: torch.Tensor,
            texts: Union[List, torch.Tensor],
            p_classifier_free: float,
            p_text_drop: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """Processes embeddings and images to predict noise for training.

        Applies classifier-free guidance and text dropout, projects CLIP image embeddings
        into context tokens, encodes text with GLIDE, and predicts noise for the diffusion process.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim). Only used to
            decide whether the caption is dropped; the GLIDE encoder receives `texts`.
        `images` : torch.Tensor
            Input images, shape (batch_size, channels, height, width).
        `texts` : Union[List, torch.Tensor]
            Text prompts for conditional generation.
        `p_classifier_free` : float
            Value compared against `classifier_free_prop`; the image embeddings are
            zeroed when it is smaller (presumably a uniform draw made by the caller).
        `p_text_drop` : float
            Value compared against `drop_caption`; the caption is dropped when it is
            smaller (presumably a uniform draw made by the caller).

        Returns
        -------
        predicted_noise : torch.Tensor
            Output of the noise predictor for the noised images.
        noise : torch.Tensor
            Ground truth noise tensor, same shape as `images`.
        """
        image_embeddings = self._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
        text_embeddings = self._apply_text_dropout(text_embeddings, p_text_drop)
        # project z_i to 4 tokens
        c = self.clip_decoder_projection(image_embeddings)
        # encode text with GLIDE (skipped when the caption was dropped)
        y_encoded = self._encode_text_with_glide(texts if text_embeddings is not None else None)
        # concatenate embeddings
        context = self._concatenate_embeddings(y_encoded, c)
        # sample timestep and noise
        t, noise = self._sample_timestep_and_noise(images.shape[0], images.shape)
        # compute noisy image
        noisy_images = self.forward_diffusion(images, noise, t)
        clip_image_embedding = self.clip_time_projection(image_embeddings)
        predicted_noise = self.noise_predictor(noisy_images, t, context, clip_image_embedding)
        return predicted_noise, noise

    def inference_forward(self, image_embeddings, prompt_embeddings):
        """Placeholder for an inference-time forward pass; currently does nothing and returns None."""
        pass

    def _apply_classifier_free_guidance(self, image_embeddings: torch.Tensor, p_value: float) -> torch.Tensor:
        """Applies classifier-free guidance to image embeddings.

        Replaces the whole batch of image embeddings with zeros when `p_value`
        is below `self.classifier_free_prop`.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against the classifier-free threshold.

        Returns
        -------
        image_embeddings : torch.Tensor
            Original or zeroed image embeddings, same shape as the input.
        """
        if p_value < self.classifier_free_prop:
            # set z_i <- 0 {classifier-free guidance}
            image_embeddings = torch.zeros_like(image_embeddings)

        return image_embeddings

    def _apply_text_dropout(self, text_embeddings: torch.Tensor, p_value: float) -> Optional[torch.Tensor]:
        """Applies text caption dropout to text embeddings.

        Returns None (caption dropped) when `p_value` is below `self.drop_caption`,
        otherwise returns the embeddings unchanged.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against the caption-dropout threshold.

        Returns
        -------
        text_embeddings : torch.Tensor or None
            The unchanged text embeddings, or None if dropped.
        """
        if p_value < self.drop_caption:
            # set y <- empty {drop text caption}
            return None

        return text_embeddings

    def _encode_text_with_glide(self, texts: Union[List, torch.Tensor]) -> Optional[torch.Tensor]:
        """Encodes text prompts using the GLIDE text encoder.

        Tokenizes the prompts with `self.bert_tokenizer` (padded/truncated to
        `self.max_token_length`) and passes the token ids and attention mask to
        the GLIDE text encoder. Returns None if no text or no encoder is available.

        Parameters
        ----------
        `texts` : Union[List, torch.Tensor]
            Text prompts or tensor of text data; tensor items are converted with `str`.

        Returns
        -------
        y_encoded : torch.Tensor or None
            Output of the GLIDE text encoder, or None.
        """
        if texts is None or self.glide_text_encoder is None:
            return None

        # convert to string list if needed
        if isinstance(texts, torch.Tensor):
            texts = [str(item) for item in texts.detach().cpu().tolist()]

        # tokenize
        tokenized = self.bert_tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        ).to(self.device)

        # get embeddings from GLIDE text encoder
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        y_encoded = self.glide_text_encoder(input_ids, attention_mask)

        return y_encoded

    def _concatenate_embeddings(self, y_encoded: Optional[torch.Tensor], c: torch.Tensor) -> torch.Tensor:
        """Concatenates GLIDE text embeddings and context tokens.

        Combines encoded text embeddings (if available) with projected context tokens
        along dimension 1. A 2-D `y_encoded` is first given a length-1 sequence axis.

        Parameters
        ----------
        `y_encoded` : torch.Tensor or None
            Encoded text embeddings from GLIDE, shape (batch_size, seq_len, embedding_dim)
            or (batch_size, embedding_dim).
        `c` : torch.Tensor
            Projected context tokens, shape (batch_size, num_tokens, embedding_dim).

        Returns
        -------
        s : torch.Tensor
            Concatenated embeddings, shape (batch_size, seq_len + num_tokens, embedding_dim),
            or `c` itself when `y_encoded` is None.
        """
        if y_encoded is None:
            return c  # [batch_size, 4, embed_dim]

        # ensure y_encoded has sequence dimension
        if y_encoded.dim() == 2:  # [batch_size, embed_dim]
            y_encoded = y_encoded.unsqueeze(1)  # [batch_size, 1, embed_dim]

        # concatenate along the sequence dimension
        return torch.cat([y_encoded, c], dim=1)  # [batch_size, seq_len + 4, embed_dim]

    def _sample_timestep_and_noise(self, batch_size: int, image_shape: torch.Size) -> Tuple[torch.Tensor, torch.Tensor]:
        """Samples timesteps and noise for the diffusion process.

        Generates random timesteps and Gaussian noise for use in the forward diffusion process.

        Parameters
        ----------
        `batch_size` : int
            Number of samples in the batch.
        `image_shape` : torch.Size
            Shape of the images, typically (batch_size, channels, height, width).

        Returns
        -------
        t : torch.Tensor
            Sampled timestep indices in [0, num_steps - 1], shape (batch_size,).
        noise : torch.Tensor
            Sampled Gaussian noise with shape `image_shape`.
        """
        # sample timestep index t uniformly from {0, ..., T - 1} (0-based indexing)
        t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
        # sample noise eps ~ N(0, I)
        noise = torch.randn(image_shape, device=self.device)
        return t, noise
908
+
909
+ ###==================================================================================================================###
910
+
911
class UnCLIPTransformerPrior(nn.Module):
    """Transformer-based prior model for UnCLIP diffusion.

    Predicts clean image embeddings from noisy image embeddings and text embeddings using
    a Transformer architecture, incorporating time embeddings and optional projection
    layers for text and image inputs.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising during training.
    `clip_text_projection` : nn.Module, optional
        Projection module for text embeddings, default None.
    `clip_image_projection` : nn.Module, optional
        Projection module for image embeddings, default None.
    `transformer_embedding_dim` : int, optional
        Dimensionality of embeddings (default: 320).
    `num_layers` : int, optional
        Number of Transformer layers (default: 12).
    `num_attention_heads` : int, optional
        Number of attention heads in each Transformer layer (default: 8).
    `feedforward_dim` : int, optional
        Dimensionality of the feedforward network in Transformer layers (default: 768).
    `max_sequence_length` : int, optional
        Number of learned positional embeddings (default: 2). `forward` builds a
        sequence of length 2 and uses only the first two positions.
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.2).
    """
    def __init__(
            self,
            forward_diffusion: nn.Module,  # will be used during training
            reverse_diffusion: nn.Module,  # will be used during training
            clip_text_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
            clip_image_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
            transformer_embedding_dim: int = 320,
            num_layers: int = 12,
            num_attention_heads: int = 8,
            feedforward_dim: int = 768,
            max_sequence_length: int = 2,
            dropout_rate: float = 0.2
    ) -> None:
        super().__init__()

        self.forward_diffusion = forward_diffusion
        self.reverse_diffusion = reverse_diffusion
        self.clip_text_projection = clip_text_projection
        self.clip_image_projection = clip_image_projection

        self.transformer_embedding_dim = transformer_embedding_dim
        self.max_sequence_length = max_sequence_length

        # time embedding network
        self.time_embedding_net = nn.Sequential(
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim),
            nn.GELU(),
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim)
        )

        # positional embeddings
        self.positional_embeddings = nn.Parameter(torch.randn(max_sequence_length, transformer_embedding_dim))

        # transformer layers
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(transformer_embedding_dim, num_attention_heads, feedforward_dim, dropout_rate)
            for _ in range(num_layers)
        ])

        # final output projection
        self.output_projection = nn.Linear(transformer_embedding_dim, transformer_embedding_dim)

    def forward(
            self,
            text_embeddings: torch.Tensor,
            noisy_image_embeddings: torch.Tensor,
            timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Predicts clean image embeddings from noisy inputs and text embeddings.

        Processes text and noisy image embeddings through a Transformer architecture,
        conditioned on time embeddings, to predict the clean image embeddings.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).
        `noisy_image_embeddings` : torch.Tensor
            Noisy image embeddings, shape (batch_size, embedding_dim).
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).

        Returns
        -------
        predicted_clean_embeddings : torch.Tensor
            Predicted clean image embeddings, shape (batch_size, embedding_dim).
        """
        device = text_embeddings.device
        # create sinusoidal time embeddings
        time_embeddings = self._get_sinusoidal_embeddings(timesteps, self.transformer_embedding_dim, device)
        time_embeddings = self.time_embedding_net(time_embeddings)
        # add time information to image embeddings
        conditioned_image_embeddings = noisy_image_embeddings + time_embeddings
        # create sequence: [text_embeddings, conditioned_image_embeddings]
        sequence = torch.stack([text_embeddings, conditioned_image_embeddings], dim=1)  # [B, 2, D]
        # add positional embeddings; slice to the actual sequence length so a
        # table built with max_sequence_length > 2 still broadcasts
        positions = self.positional_embeddings[: sequence.size(1)]
        sequence = sequence + positions.unsqueeze(0)
        # pass through transformer blocks
        for transformer_block in self.transformer_blocks:
            sequence = transformer_block(sequence)
        # extract predicted clean image embedding (second position in sequence)
        predicted_clean_embeddings = sequence[:, 1, :]  # [B, D]
        # apply final projection
        return self.output_projection(predicted_clean_embeddings)

    def _get_sinusoidal_embeddings(
            self,
            timesteps: torch.Tensor,
            embedding_dim: int,
            device: Union[torch.device, str]
    ) -> torch.Tensor:
        """Generates sinusoidal positional embeddings for timesteps.

        The first half of each embedding holds sines and the second half cosines
        of the timestep scaled by geometrically spaced frequencies; an odd
        `embedding_dim` is padded with one trailing zero column.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).
        `embedding_dim` : int
            Dimensionality of the embeddings.
        `device` : Union[torch.device, str]
            Device on which the frequency table is created.

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal time embeddings, shape (batch_size, embedding_dim).
        """
        half_dim = embedding_dim // 2
        # guard the denominator: half_dim == 1 (embedding_dim 2 or 3) would
        # otherwise divide by zero; results for half_dim >= 2 are unchanged
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = timesteps[:, None].float() * frequencies[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

        # handle odd embedding dimensions; new_zeros gives a real (B, 1) column
        # even when emb has zero width (embedding_dim == 1)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)

        return emb
1064
+
1065
+
1066
class TransformerBlock(nn.Module):
    """Post-norm Transformer block: self-attention followed by an MLP.

    Each sub-layer output is added back to its input and the sum is layer
    normalized. Used as the building block of the UnCLIPTransformerPrior model.

    Parameters
    ----------
    `embedding_dim` : int
        Dimensionality of input and output embeddings.
    `num_heads` : int
        Number of attention heads in the multi-head attention layer.
    `feedforward_dim` : int
        Dimensionality of the feedforward network.
    `dropout` : float
        Dropout probability for regularization.
    """

    def __init__(
            self,
            embedding_dim: int,
            num_heads: int,
            feedforward_dim: int,
            dropout: float
    ) -> None:
        super().__init__()

        # batch_first=True: inputs are (batch, sequence, embedding)
        self.self_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attention_norm = nn.LayerNorm(embedding_dim)
        self.feedforward_norm = nn.LayerNorm(embedding_dim)

        # expand -> GELU -> dropout -> contract -> dropout
        mlp_layers = [
            nn.Linear(embedding_dim, feedforward_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.feedforward = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the sequence through attention and the feedforward network.

        Parameters
        ----------
        `x` : torch.Tensor
            Input sequence tensor, shape (batch_size, sequence_length, embedding_dim).

        Returns
        -------
        output : torch.Tensor
            Processed sequence tensor, shape (batch_size, sequence_length, embedding_dim).
        """
        # attention sub-layer: residual add, then normalize
        attended = self.self_attention(x, x, x)[0]
        hidden = self.attention_norm(x + attended)

        # feedforward sub-layer: residual add, then normalize
        return self.feedforward_norm(hidden + self.feedforward(hidden))
1136
+
1137
+ ###==================================================================================================================###
1138
+
1139
class CLIPContextProjection(nn.Module):
    """Expand one CLIP image embedding into several context tokens.

    A single linear layer maps the embedding to `num_tokens` embeddings of the
    same width, which are then layer normalized token by token.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input CLIP embedding (e.g., 319 or 512).
    `num_tokens` : int, optional
        Number of context tokens to generate (default: 4).
    """
    def __init__(self, clip_embedding_dim: int, num_tokens: int = 4) -> None:
        super().__init__()
        self.clip_embedding_dim = clip_embedding_dim
        self.num_tokens = num_tokens
        # one wide projection yields all tokens at once
        self.clip_projection = nn.Linear(clip_embedding_dim, clip_embedding_dim * num_tokens)
        self.clip_embedding_norm = nn.LayerNorm(clip_embedding_dim)

    def forward(self, z_i: torch.Tensor) -> torch.Tensor:
        """Turn a batch of CLIP image embeddings into context tokens.

        Parameters
        ----------
        `z_i` : torch.Tensor
            Input CLIP image embedding, shape (batch_size, input_dim).

        Returns
        -------
        c : torch.Tensor
            Context tokens, shape (batch_size, num_tokens, input_dim).
        """
        # split the wide projection into (num_tokens, clip_embedding_dim) per sample
        tokens = self.clip_projection(z_i).view(
            z_i.shape[0], self.num_tokens, self.clip_embedding_dim
        )
        return self.clip_embedding_norm(tokens)
1180
+
1181
+ ###==================================================================================================================###
1182
+
1183
class CLIPEmbeddingProjection(nn.Module):
    """Learned projection pair for reducing and restoring embedding width.

    Holds two independent MLPs: `forward_projection` maps from
    `clip_embedding_dim` to `transformer_embedding_dim`, and
    `inverse_projection` maps back.

    Parameters
    ----------
    `clip_embedding_dim` : int, optional
        Input dimensionality (default: 1024).
    `transformer_embedding_dim` : int, optional
        Output dimensionality for forward projection (default: 320).
    `hidden_dim` : int, optional
        Hidden layer dimensionality (default: 512).
    `num_layers` : int, optional
        Number of layers in the projection network (default: 2).
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.2).
    `use_layer_norm` : bool, optional
        Whether to apply layer normalization after hidden layers (default: True).
    """
    def __init__(
            self,
            clip_embedding_dim: int = 1024,
            transformer_embedding_dim: int = 320,
            hidden_dim: int = 512,
            num_layers: int = 2,
            dropout_rate: float = 0.2,
            use_layer_norm: bool = True
    ) -> None:
        super().__init__()

        self.clip_embedding_dim = clip_embedding_dim
        self.transformer_embedding_dim = transformer_embedding_dim

        # settings shared by both directions
        shared_args = (hidden_dim, num_layers, dropout_rate, use_layer_norm)

        # reduce: clip_embedding_dim -> transformer_embedding_dim
        self.forward_projection = self._build_projection_network(
            clip_embedding_dim, transformer_embedding_dim, *shared_args
        )

        # restore: transformer_embedding_dim -> clip_embedding_dim
        self.inverse_projection = self._build_projection_network(
            transformer_embedding_dim, clip_embedding_dim, *shared_args
        )

    def _build_projection_network(
            self,
            input_dim: int,
            output_dim: int,
            hidden_dim: int,
            num_layers: int,
            dropout: float,
            use_layer_norm: bool
    ) -> nn.Sequential:
        """Assemble an MLP with `num_layers - 1` hidden stages and a linear head.

        Each hidden stage is Linear, optional LayerNorm, GELU, Dropout; the
        final layer is a plain Linear to `output_dim`.

        Parameters
        ----------
        `input_dim` : int
            Input dimensionality for the network.
        `output_dim` : int
            Output dimensionality for the network.
        `hidden_dim` : int
            Hidden layer dimensionality.
        `num_layers` : int
            Number of layers in the network.
        `dropout` : float
            Dropout probability for regularization.
        `use_layer_norm` : bool
            Whether to apply layer normalization after hidden layers.

        Returns
        -------
        network : nn.Sequential
            Sequential container of the projection network layers.
        """
        # widths seen at the boundaries of the hidden stages
        widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 0)

        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            if use_layer_norm:
                modules.append(nn.LayerNorm(fan_out))
            modules.extend([nn.GELU(), nn.Dropout(dropout)])

        # linear head from the last width to the requested output size
        modules.append(nn.Linear(widths[-1], output_dim))

        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the forward projection network.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor to be projected, shape (batch_size, input_dim).

        Returns
        -------
        x_reduced : torch.Tensor
            Projected tensor, shape (batch_size, output_dim).
        """
        return self.forward_projection(x)

    def inverse_transform(self, x_reduced: torch.Tensor) -> torch.Tensor:
        """Apply the inverse projection network.

        Parameters
        ----------
        `x_reduced` : torch.Tensor
            Reduced-dimensionality tensor, shape (batch_size, output_dim).

        Returns
        -------
        x_reconstructed : torch.Tensor
            Reconstructed tensor, shape (batch_size, input_dim).
        """
        return self.inverse_projection(x_reduced)

    def reconstruction_loss(self, x: torch.Tensor) -> torch.Tensor:
        """Mean squared error between `x` and its round-trip reconstruction.

        Parameters
        ----------
        `x` : torch.Tensor
            Original input tensor, shape (batch_size, input_dim).

        Returns
        -------
        loss : torch.Tensor
            Mean squared error loss between the original and reconstructed tensors.
        """
        # reduce, then restore, then compare with the input
        round_trip = self.inverse_transform(self.forward(x))
        return F.mse_loss(round_trip, x)
1333
+
1334
+ ###==================================================================================================================###
1335
+
1336
+ class TrainUnClipDecoder(nn.Module):
1337
+ """Trainer for the UnCLIP decoder model.
1338
+
1339
+ Orchestrates the training of the UnCLIP decoder model, integrating CLIP embeddings, forward
1340
+ and reverse diffusion processes, and optional dimensionality reduction. Supports mixed
1341
+ precision, gradient accumulation, DDP, and comprehensive evaluation metrics.
1342
+
1343
+ Parameters
1344
+ ----------
1345
+ `clip_embedding_dim` : int
1346
+ Dimensionality of the input embeddings.
1347
+ `decoder_model` : nn.Module
1348
+ The UnCLIP decoder model (e.g., UnClipDecoder) to be trained.
1349
+ `clip_model` : nn.Module
1350
+ CLIP model for generating text and image embeddings.
1351
+ `train_loader` : torch.utils.data.DataLoader
1352
+ DataLoader for training data.
1353
+ `optimizer` : torch.optim.Optimizer
1354
+ Optimizer for training the decoder model.
1355
+ `objective` : Callable
1356
+ Loss function to compute the difference between predicted and target noise.
1357
+ `clip_text_projection` : nn.Module, optional
1358
+ Projection module for text embeddings, default None.
1359
+ `clip_image_projection` : nn.Module, optional
1360
+ Projection module for image embeddings, default None.
1361
+ `val_loader` : torch.utils.data.DataLoader, optional
1362
+ DataLoader for validation data, default None.
1363
+ `metrics_` : Any, optional
1364
+ Object providing evaluation metrics (e.g., FID, MSE, PSNR, SSIM, LPIPS), default None.
1365
+ `max_epochs` : int, optional
1366
+ Maximum number of training epochs (default: 1000).
1367
+ `device` : Union[str, torch.device], optional
1368
+ Device for computation (default: CUDA if available, else CPU).
1369
+ `store_path` : str, optional
1370
+ Directory to save model checkpoints (default: "unclip_decoder").
1371
+ `patience` : int, optional
1372
+ Number of epochs to wait for improvement before early stopping (default: 100).
1373
+ `warmup_epochs` : int, optional
1374
+ Number of epochs for learning rate warmup (default: 100).
1375
+ `val_frequency` : int, optional
1376
+ Frequency (in epochs) for validation (default: 10).
1377
+ `use_ddp` : bool, optional
1378
+ Whether to use Distributed Data Parallel training (default: False).
1379
+ `grad_accumulation_steps` : int, optional
1380
+ Number of gradient accumulation steps before optimizer update (default: 1).
1381
+ `log_frequency` : int, optional
1382
+ Frequency (in epochs) for printing progress (default: 1).
1383
+ `use_compilation` : bool, optional
1384
+ Whether to compile the model using torch.compile (default: False).
1385
+ `image_output_range` : Tuple[float, float], optional
1386
+ Range for clamping output images (default: (-1.0, 1.0)).
1387
+ `reduce_clip_embedding_dim` : bool, optional
1388
+ Whether to apply dimensionality reduction to embeddings (default: True).
1389
+ `transformer_embedding_dim` : int, optional
1390
+ Output dimensionality for reduced embeddings (default: 312).
1391
+ `normalize_clip_embeddings` : bool, optional
1392
+ Whether to normalize CLIP embeddings (default: True).
1393
+ `finetune_clip_projections` : bool, optional
1394
+ Whether to fine-tune projection layers (default: False).
1395
+ """
1396
def __init__(
        self,
        clip_embedding_dim: int,
        decoder_model: nn.Module,
        clip_model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        clip_text_projection: Optional[nn.Module] = None,
        clip_image_projection: Optional[nn.Module] = None,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        metrics_: Optional[Any] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        store_path: str = "unclip_decoder",
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        reduce_clip_embedding_dim: bool = True,
        transformer_embedding_dim: int = 312,
        normalize_clip_embeddings: bool = True,
        finetune_clip_projections: bool = False  # whether the text/image projection models are fine-tuned
):
    """Stores the training configuration and prepares models and schedulers.

    Resolves the device, configures DDP or single-device execution, moves the models
    to the resolved device, optionally compiles and DDP-wraps them, and builds the
    plateau and warmup learning rate schedulers. Parameters are described in the
    class docstring.
    """
    super().__init__()
    # training configuration
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    self.use_compilation = use_compilation
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # These flags are read by _compile_models and _wrap_models_for_ddp, so they
    # must be assigned before either of those helpers runs.
    self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
    self.finetune_clip_projections = finetune_clip_projections

    # setup distributed training first: _setup_ddp rebinds self.device to the
    # GPU of the current rank, and the models must land on that device
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # core models
    self.decoder_model = decoder_model.to(self.device)
    self.clip_model = clip_model.to(self.device)

    # projection models (PCA equivalent in the paper); assigned before compiling
    # and wrapping because both helpers inspect these attributes
    if self.reduce_clip_embedding_dim and clip_text_projection is not None and clip_image_projection is not None:
        self.clip_text_projection = clip_text_projection.to(self.device)
        self.clip_image_projection = clip_image_projection.to(self.device)
    else:
        self.clip_text_projection = None
        self.clip_image_projection = None

    # compile and wrap models
    self._compile_models()
    self._wrap_models_for_ddp()

    # training components
    self.clip_embedding_dim = transformer_embedding_dim if self.reduce_clip_embedding_dim else clip_embedding_dim
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.train_loader = train_loader
    self.val_loader = val_loader

    # training parameters
    self.max_epochs = max_epochs
    self.patience = patience
    self.val_frequency = val_frequency
    self.log_frequency = log_frequency
    self.image_output_range = image_output_range
    self.normalize_clip_embeddings = normalize_clip_embeddings
    self.transformer_embedding_dim = transformer_embedding_dim

    # checkpoint management
    self.store_path = store_path

    # learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
1490
+
1491
def forward(self) -> Tuple[List[float], float]:
    """Trains the UnCLIP decoder model to predict noise for denoising.

    Executes the training loop, optimizing the decoder model using CLIP embeddings, mixed
    precision, gradient clipping, and learning rate scheduling. Supports validation, early
    stopping, and checkpointing. On epochs where validation runs, the validation loss
    (instead of the training loss) drives the plateau scheduler, best-model tracking and
    early stopping.

    Returns
    -------
    train_losses : List[float]
        List of mean training losses per epoch.
    best_val_loss : float
        Best validation or training loss achieved (tracked on the master process).
    """
    # DistributedDataParallel does not forward attribute access to the wrapped
    # model, so submodules are reached through `.module` when DDP is active
    decoder = self.decoder_model.module if self.use_ddp else self.decoder_model
    variance_scheduler = decoder.forward_diffusion.variance_scheduler
    use_projections = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
    )

    # set models to training mode
    self.decoder_model.train()
    if not variance_scheduler.trainable_beta:  # if beta is not trainable
        variance_scheduler.eval()

    # set text_projection and image_projection to train mode if fine-tuning
    if use_projections:
        if self.finetune_clip_projections:
            self.clip_text_projection.train()
            self.clip_image_projection.train()
        else:
            self.clip_text_projection.eval()
            self.clip_image_projection.eval()

    # set CLIP model to eval mode (frozen)
    if self.clip_model is not None:
        self.clip_model.eval()

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0
    autocast_device = 'cuda' if self.device.type == 'cuda' else 'cpu'

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (images, texts) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
            images = images.to(self.device, non_blocking=True)

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                # encode text and image with CLIP
                text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)

                # reduce dimensionality (PCA equivalent)
                text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                    text_embeddings, image_embeddings
                )

                # use decoder model to predict noise
                p_classifier_free = torch.rand(1).item()
                p_text_drop = torch.rand(1).item()
                predicted_noise, noise = self.decoder_model(
                    image_embeddings,
                    text_embeddings,
                    images,
                    texts,
                    p_classifier_free,
                    p_text_drop
                )

                # compute loss (scaled so accumulated gradients average over the window)
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            scaler.scale(loss).backward()

            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), max_norm=1.0)  # covers all submodules
                if use_projections and self.finetune_clip_projections:
                    torch.nn.utils.clip_grad_norm_(self.clip_text_projection.parameters(), max_norm=1.0)
                    torch.nn.utils.clip_grad_norm_(self.clip_image_projection.parameters(), max_norm=1.0)

                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()
                # NOTE(review): the warmup scheduler advances once per optimizer step,
                # although its horizon is named `warmup_epochs` - confirm intended unit
                self.warmup_lr_scheduler.step()
                torch.cuda.empty_cache()  # clear memory after optimizer step

            # undo the accumulation scaling so the logged value is the raw loss
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        mean_train_loss = self._compute_mean_loss(train_losses_epoch)
        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        current_loss = mean_train_loss

        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()
            # monitor the validation loss whenever one is available
            current_loss = val_loss

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

        self.scheduler.step(current_loss)

        should_stop = False
        if self.master_process:
            if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                    should_stop = True

        if self.use_ddp:
            # every rank must leave the loop together; otherwise the remaining ranks
            # block forever in their next collective operation
            stop_flag = torch.tensor(int(should_stop), device=self.device)
            dist.broadcast(stop_flag, src=0)
            should_stop = bool(stop_flag.item())

        if should_stop:
            break

    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
1625
+
1626
+ def _setup_ddp(self) -> None:
1627
+ """Sets up Distributed Data Parallel training configuration.
1628
+
1629
+ Initializes the process group, sets up rank information, and configures the CUDA
1630
+ device for the current process in DDP mode.
1631
+ """
1632
+ required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
1633
+ for var in required_env_vars:
1634
+ if var not in os.environ:
1635
+ raise ValueError(f"DDP enabled but {var} environment variable not set")
1636
+
1637
+ if not torch.cuda.is_available():
1638
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
1639
+
1640
+ if not torch.distributed.is_initialized():
1641
+ init_process_group(backend="nccl")
1642
+
1643
+ self.ddp_rank = int(os.environ["RANK"])
1644
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
1645
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"])
1646
+
1647
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
1648
+ torch.cuda.set_device(self.device)
1649
+
1650
+ self.master_process = self.ddp_rank == 0
1651
+
1652
+ if self.master_process:
1653
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
1654
+
1655
+ def _setup_single_gpu(self) -> None:
1656
+ """Sets up single GPU or CPU training configuration.
1657
+
1658
+ Configures the training setup for single-device operation, setting rank and process
1659
+ information for non-DDP training.
1660
+ """
1661
+ self.ddp_rank = 0
1662
+ self.ddp_local_rank = 0
1663
+ self.ddp_world_size = 1
1664
+ self.master_process = True
1665
+
1666
+ @staticmethod
1667
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
1668
+ """Creates a learning rate scheduler for warmup.
1669
+
1670
+ Generates a scheduler that linearly increases the learning rate from 0 to the
1671
+ optimizer's initial value over the specified warmup epochs.
1672
+
1673
+ Parameters
1674
+ ----------
1675
+ `optimizer` : torch.optim.Optimizer
1676
+ Optimizer to apply the scheduler to.
1677
+ `warmup_epochs` : int
1678
+ Number of epochs for the warmup phase.
1679
+
1680
+ Returns
1681
+ -------
1682
+ lr_scheduler : torch.optim.lr_scheduler.LambdaLR
1683
+ Learning rate scheduler for warmup.
1684
+ """
1685
+ def lr_lambda(epoch):
1686
+ return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
1687
+
1688
+ return LambdaLR(optimizer, lr_lambda)
1689
+
1690
+ def _wrap_models_for_ddp(self) -> None:
1691
+ """Wraps models with DistributedDataParallel for multi-GPU training.
1692
+
1693
+ Configures the decoder model and, if fine-tuning, the projection models for DDP training.
1694
+ """
1695
+ if self.use_ddp:
1696
+ self.decoder_model = self.decoder_model.to(self.ddp_local_rank)
1697
+ self.decoder_model = DDP(
1698
+ self.decoder_model,
1699
+ device_ids=[self.ddp_local_rank],
1700
+ find_unused_parameters=True
1701
+ )
1702
+ # only wrap text_projection and image_projection if they are trainable
1703
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
1704
+ self.clip_text_projection = self.clip_text_projection.to(self.ddp_local_rank)
1705
+ self.clip_image_projection = self.clip_image_projection.to(self.ddp_local_rank)
1706
+ self.clip_text_projection = DDP(self.clip_text_projection, device_ids=[self.ddp_local_rank])
1707
+ self.clip_image_projection = DDP(self.clip_image_projection, device_ids=[self.ddp_local_rank])
1708
+
1709
+ def _compile_models(self) -> None:
1710
+ """Compiles models for optimization if supported.
1711
+
1712
+ Attempts to compile the decoder model and, if fine-tuning, the projection models using
1713
+ torch.compile for optimization, falling back to uncompiled execution if compilation fails.
1714
+ """
1715
+ if self.use_compilation:
1716
+ try:
1717
+ self.decoder_model = self.decoder_model.to(self.device)
1718
+ self.decoder_model = torch.compile(self.decoder_model, mode="reduce-overhead")
1719
+ # only compile text_projection and image_projection if they are trainable
1720
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
1721
+ self.clip_text_projection = self.clip_text_projection.to(self.device)
1722
+ self.clip_image_projection = self.clip_image_projection.to(self.device)
1723
+ self.clip_text_projection = torch.compile(self.clip_text_projection, mode="reduce-overhead")
1724
+ self.clip_image_projection = torch.compile(self.clip_image_projection, mode="reduce-overhead")
1725
+ if self.master_process:
1726
+ print("Models compiled successfully")
1727
+ except Exception as e:
1728
+ if self.master_process:
1729
+ print(f"Model compilation failed: {e}. Continuing without compilation.")
1730
+
1731
+ def _get_clip_embeddings(
1732
+ self,
1733
+ images: torch.Tensor,
1734
+ texts: Union[List, torch.Tensor]
1735
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1736
+ """Encodes images and texts using the CLIP model.
1737
+
1738
+ Generates text and image embeddings using the CLIP model, with optional normalization.
1739
+
1740
+ Parameters
1741
+ ----------
1742
+ `images` : torch.Tensor
1743
+ Input images, shape (batch_size, channels, height, width).
1744
+ `texts` : Union[List, torch.Tensor]
1745
+ Text prompts for conditional generation.
1746
+
1747
+ Returns
1748
+ -------
1749
+ text_embeddings : torch.Tensor
1750
+ CLIP text embeddings, shape (batch_size, embedding_dim).
1751
+ image_embeddings : torch.Tensor
1752
+ CLIP image embeddings, shape (batch_size, embedding_dim).
1753
+ """
1754
+ with torch.no_grad():
1755
+ # encode text y with CLIP text encoder: z_t ← CLIP_text(y)
1756
+ text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
1757
+ # encode image x with CLIP image encoder: z_i ← CLIP_image(x)
1758
+ image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
1759
+ return text_embeddings, image_embeddings
1760
+
1761
+ def _apply_dimensionality_reduction(
1762
+ self,
1763
+ text_embeddings: torch.Tensor,
1764
+ image_embeddings: torch.Tensor
1765
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1766
+ """Applies dimensionality reduction to embeddings if enabled.
1767
+
1768
+ Projects text and image embeddings to a lower-dimensional space using learned
1769
+ projection layers, mimicking PCA as used in the UnCLIP paper.
1770
+
1771
+ Parameters
1772
+ ----------
1773
+ `text_embeddings` : torch.Tensor
1774
+ CLIP text embeddings, shape (batch_size, embedding_dim).
1775
+ `image_embeddings` : torch.Tensor
1776
+ CLIP image embeddings, shape (batch_size, embedding_dim).
1777
+
1778
+ Returns
1779
+ -------
1780
+ text_embeddings : torch.Tensor
1781
+ Projected text embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
1782
+ image_embeddings : torch.Tensor
1783
+ Projected image embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
1784
+ """
1785
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
1786
+ if not self.finetune_clip_projections:
1787
+ with torch.no_grad():
1788
+ text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
1789
+ image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
1790
+ else:
1791
+ text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
1792
+ image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
1793
+ return text_embeddings.to(self.device), image_embeddings.to(self.device)
1794
+
1795
+ def _compute_mean_loss(self, losses: List[float]) -> float:
1796
+ """Computes mean loss with DDP synchronization if needed.
1797
+
1798
+ Calculates the mean of the provided losses and synchronizes the result across
1799
+ processes in DDP mode.
1800
+
1801
+ Parameters
1802
+ ----------
1803
+ `losses` : List[float]
1804
+ List of loss values for the current epoch.
1805
+
1806
+ Returns
1807
+ -------
1808
+ mean_loss : float
1809
+ Mean loss value, synchronized if using DDP.
1810
+ """
1811
+ if not losses:
1812
+ return 0.0
1813
+ mean_loss = sum(losses) / len(losses)
1814
+ if self.use_ddp:
1815
+ # synchronize loss across all processes
1816
+ loss_tensor = torch.tensor(mean_loss, device=self.device)
1817
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
1818
+ mean_loss = (loss_tensor / self.ddp_world_size).item()
1819
+
1820
+ return mean_loss
1821
+
1822
+ def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
1823
+ """Saves model checkpoint.
1824
+
1825
+ Saves the state of the decoder model, its submodules, optimizer, and schedulers,
1826
+ with options for best model and epoch-specific checkpoints.
1827
+
1828
+ Parameters
1829
+ ----------
1830
+ `epoch` : int
1831
+ Current epoch number.
1832
+ `loss` : float
1833
+ Current loss value.
1834
+ `is_best` : bool, optional
1835
+ Whether to save as the best model checkpoint (default: False).
1836
+ `suffix` : str, optional
1837
+ Suffix to add to checkpoint filename, default "".
1838
+ """
1839
+ if not self.master_process:
1840
+ return
1841
+ checkpoint = {
1842
+ 'epoch': epoch,
1843
+ 'loss': loss,
1844
+ # core models (submodules of decoder_model)
1845
+ 'noise_predictor_state_dict': self.decoder_model.module.noise_predictor.state_dict() if self.use_ddp else self.decoder_model.noise_predictor.state_dict(),
1846
+ 'optimizer_state_dict': self.optimizer.state_dict(),
1847
+ # training configuration
1848
+ 'embedding_dim': self.clip_embedding_dim,
1849
+ 'output_dim': self.transformer_embedding_dim,
1850
+ 'reduce_dim': self.reduce_clip_embedding_dim,
1851
+ 'normalize': self.normalize_clip_embeddings
1852
+ }
1853
+
1854
+ # save conditional model (submodule of decoder_model)
1855
+ if self.decoder_model.glide_text_encoder is not None:
1856
+ checkpoint['conditional_model_state_dict'] = (
1857
+ self.decoder_model.module.glide_text_encoder.state_dict() if self.use_ddp
1858
+ else self.decoder_model.glide_text_encoder.state_dict()
1859
+ )
1860
+
1861
+ # save variance scheduler (submodule of decoder_model, always saved)
1862
+ checkpoint['variance_scheduler_state_dict'] = (
1863
+ self.decoder_model.forward_diffusion.module.variance_scheduler.state_dict() if self.use_ddp
1864
+ else self.decoder_model.forward_diffusion.variance_scheduler.state_dict()
1865
+ )
1866
+
1867
+ # save CLIP time projection layer (submodule of decoder_model)
1868
+ checkpoint['clip_time_proj_state_dict'] = (
1869
+ self.decoder_model.module.clip_time_projection.state_dict() if self.use_ddp
1870
+ else self.decoder_model.clip_time_projection.state_dict()
1871
+ )
1872
+
1873
+ # save decoder projection layer (submodule of decoder_model)
1874
+ checkpoint['decoder_projection_state_dict'] = (
1875
+ self.decoder_model.module.clip_decoder_projection.state_dict() if self.use_ddp
1876
+ else self.decoder_model.clip_decoder_projection.state_dict()
1877
+ )
1878
+ # a nn.Linear projection layer
1879
+ checkpoint['clip_time_projection_state_dict'] = (
1880
+ self.decoder_model.module.clip_time_projection.state_dict() if self.use_ddp
1881
+ else self.decoder_model.clip_time_projection.state_dict()
1882
+ )
1883
+
1884
+ # save projection models (PCA equivalent)
1885
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
1886
+ checkpoint['text_projection_state_dict'] = (
1887
+ self.clip_text_projection.module.state_dict() if self.use_ddp
1888
+ else self.clip_text_projection.state_dict()
1889
+ )
1890
+ checkpoint['image_projection_state_dict'] = (
1891
+ self.clip_image_projection.module.state_dict() if self.use_ddp
1892
+ else self.clip_image_projection.state_dict()
1893
+ )
1894
+
1895
+ # save schedulers state
1896
+ checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
1897
+ checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()
1898
+
1899
+ filename = f"unclip_decoder_epoch_{epoch}{suffix}.pth"
1900
+ if is_best:
1901
+ filename = f"unclip_decoder_best{suffix}.pth"
1902
+
1903
+ filepath = os.path.join(self.store_path, filename)
1904
+ os.makedirs(self.store_path, exist_ok=True)
1905
+ torch.save(checkpoint, filepath)
1906
+
1907
+ if is_best:
1908
+ print(f"Best model saved: {filepath}")
1909
+
1910
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads model checkpoint.

    Restores the state of the decoder model's submodules, the projection models, the
    optimizer and both schedulers from a saved checkpoint. Works whether or not the
    checkpoint or the current models are wrapped in DistributedDataParallel. A
    component that fails to load produces a warning instead of aborting the restore.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if not recorded).
    loss : float
        The loss at the checkpoint (inf if not recorded).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from err

    ddp_prefix = 'module.'

    def _unwrap(model):
        # DDP does not forward attribute access; reach the real model via `.module`
        return model.module if isinstance(model, DDP) else model

    def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
        """Helper function to load state dict with DDP compatibility."""
        try:
            target = _unwrap(model)
            # Strip a leading 'module.' (left by saving a DDP wrapper) only when every
            # key has it and the target itself does not expect such keys. Only the
            # prefix is removed, so names like 'submodule.weight' stay intact.
            saved_with_wrapper = bool(state_dict) and all(k.startswith(ddp_prefix) for k in state_dict)
            target_expects_prefix = any(k.startswith(ddp_prefix) for k in target.state_dict())
            if saved_with_wrapper and not target_expects_prefix:
                state_dict = {k[len(ddp_prefix):]: v for k, v in state_dict.items()}

            target.load_state_dict(state_dict)
            if self.master_process:
                print(f"✓ Loaded {model_name}")
        except Exception as e:
            # best effort: report and continue with the remaining components
            warnings.warn(f"Failed to load {model_name}: {e}")

    def _load_training_state(component, key: str, label: str, failure_label: str) -> None:
        """Restores optimizer or scheduler state stored under `key`, if present."""
        if key not in checkpoint:
            return
        try:
            component.load_state_dict(checkpoint[key])
            if self.master_process:
                print(f"✓ Loaded {label}")
        except Exception as e:
            warnings.warn(f"Failed to load {failure_label}: {e}")

    decoder = _unwrap(self.decoder_model)

    # load core noise predictor model (submodule of decoder_model)
    if 'noise_predictor_state_dict' in checkpoint:
        _load_model_state_dict(decoder.noise_predictor, checkpoint['noise_predictor_state_dict'],
                               'noise_predictor')

    # load conditional model (submodule of decoder_model)
    if decoder.glide_text_encoder is not None and 'conditional_model_state_dict' in checkpoint:
        _load_model_state_dict(decoder.glide_text_encoder, checkpoint['conditional_model_state_dict'],
                               'glide_text_encoder')

    # load variance scheduler (submodule of decoder_model)
    if 'variance_scheduler_state_dict' in checkpoint:
        _load_model_state_dict(decoder.forward_diffusion.variance_scheduler,
                               checkpoint['variance_scheduler_state_dict'], 'variance_scheduler')

    # load CLIP time projection layer (submodule of decoder_model); checkpoints store
    # it under two keys, so the second one serves as a fallback for the first
    clip_time_state = checkpoint.get('clip_time_proj_state_dict',
                                     checkpoint.get('clip_time_projection_state_dict'))
    if clip_time_state is not None:
        _load_model_state_dict(decoder.clip_time_projection, clip_time_state, 'clip_time_projection')

    # load decoder projection layer (submodule of decoder_model)
    if 'decoder_projection_state_dict' in checkpoint:
        _load_model_state_dict(decoder.clip_decoder_projection,
                               checkpoint['decoder_projection_state_dict'], 'clip_decoder_projection')

    # load projection models (PCA equivalent)
    if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
        if 'text_projection_state_dict' in checkpoint:
            _load_model_state_dict(self.clip_text_projection, checkpoint['text_projection_state_dict'],
                                   'text_projection')
        if 'image_projection_state_dict' in checkpoint:
            _load_model_state_dict(self.clip_image_projection, checkpoint['image_projection_state_dict'],
                                   'image_projection')

    # load optimizer and schedulers
    _load_training_state(self.optimizer, 'optimizer_state_dict', 'optimizer', 'optimizer state')
    _load_training_state(self.scheduler, 'scheduler_state_dict', 'main scheduler', 'scheduler state')
    _load_training_state(self.warmup_lr_scheduler, 'warmup_scheduler_state_dict',
                         'warmup scheduler', 'warmup scheduler state')

    # verify configuration compatibility
    if 'embedding_dim' in checkpoint and checkpoint['embedding_dim'] != self.clip_embedding_dim:
        warnings.warn(
            f"Embedding dimension mismatch: checkpoint={checkpoint['embedding_dim']}, current={self.clip_embedding_dim}")

    if 'reduce_dim' in checkpoint and checkpoint['reduce_dim'] != self.reduce_clip_embedding_dim:
        warnings.warn(
            f"Reduce dimension setting mismatch: checkpoint={checkpoint['reduce_dim']}, current={self.reduce_clip_embedding_dim}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

    return epoch, loss
2043
+
2044
+
2045
def validate(self) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Validates the UnCLIP decoder model.

    Computes the validation loss over `val_loader`. When a metrics object and a
    reverse diffusion process are available, also reconstructs images by running the
    reverse diffusion loop from random noise and scores them against the originals
    (FID, MSE, PSNR, SSIM, LPIPS, depending on what the metrics object enables).
    Models are switched to eval mode for the duration and restored afterwards.

    Returns
    -------
    val_loss : float
        Mean validation loss.
    fid_avg : float
        Average FID score, or inf if it was not computed.
    mse_avg : float or None
        Average MSE score, if computed.
    psnr_avg : float or None
        Average PSNR score, if computed.
    ssim_avg : float or None
        Average SSIM score, if computed.
    lpips_avg : float or None
        Average LPIPS score, if computed.
    """
    # DistributedDataParallel does not forward attribute access to the wrapped
    # model, so submodules and helper methods are reached through `.module`
    decoder = self.decoder_model.module if self.use_ddp else self.decoder_model
    use_projections = (
        self.reduce_clip_embedding_dim
        and self.clip_text_projection is not None
        and self.clip_image_projection is not None
    )

    # set models to eval mode for evaluation
    self.decoder_model.eval()
    if use_projections:
        self.clip_text_projection.eval()
        self.clip_image_projection.eval()
    if self.clip_model is not None:
        self.clip_model.eval()

    val_losses = []
    fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

    with torch.no_grad():
        for images, texts in self.val_loader:
            images = images.to(self.device, non_blocking=True)
            images_orig = images.clone()
            text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
            text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                text_embeddings, image_embeddings
            )
            p_classifier_free = torch.rand(1).item()
            p_text_drop = torch.rand(1).item()
            predicted_noise, noise = self.decoder_model(
                image_embeddings,
                text_embeddings,
                images,
                texts,
                p_classifier_free,
                p_text_drop
            )
            loss = self.objective(predicted_noise, noise)
            val_losses.append(loss.item())

            if self.metrics_ is not None and decoder.reverse_diffusion is not None:
                # start from pure noise and denoise step by step
                xt = torch.randn_like(images).to(self.device)
                for t in reversed(range(decoder.forward_diffusion.variance_scheduler.tau_num_steps)):
                    time_steps = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
                    prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device, dtype=torch.long)
                    # NOTE(review): the embeddings are reassigned on every step, so any
                    # dropout applied by these helpers compounds across steps - confirm
                    image_embeddings = decoder._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
                    text_embeddings = decoder._apply_text_dropout(text_embeddings, p_text_drop)
                    c = decoder.clip_decoder_projection(image_embeddings)
                    y_encoded = decoder._encode_text_with_glide(texts if text_embeddings is not None else None)
                    context = decoder._concatenate_embeddings(y_encoded, c)
                    clip_image_embedding = decoder.clip_time_projection(image_embeddings)
                    predicted_noise = decoder.noise_predictor(xt, time_steps, context, clip_image_embedding)
                    xt, _ = decoder.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)

                low, high = self.image_output_range
                x_hat = torch.clamp(xt, min=low, max=high)

                if self.normalize_clip_embeddings:
                    # rescale both images from the output range to [0, 1]
                    x_hat = (x_hat - low) / (high - low)
                    x_orig = (images_orig - low) / (high - low)
                else:
                    # compare in the original range; x_orig must exist on this path too
                    x_orig = images_orig

                metrics_result = self.metrics_.forward(x_orig, x_hat)
                fid = metrics_result[0] if getattr(self.metrics_, 'fid', False) else float('inf')
                mse = metrics_result[1] if getattr(self.metrics_, 'metrics', False) else None
                psnr = metrics_result[2] if getattr(self.metrics_, 'metrics', False) else None
                ssim = metrics_result[3] if getattr(self.metrics_, 'metrics', False) else None
                lpips_score = metrics_result[4] if getattr(self.metrics_, 'lpips', False) else None

                if fid != float('inf'):
                    fid_scores.append(fid)
                if mse is not None:
                    mse_scores.append(mse)
                if psnr is not None:
                    psnr_scores.append(psnr)
                if ssim is not None:
                    ssim_scores.append(ssim)
                if lpips_score is not None:
                    lpips_scores.append(lpips_score)

    # compute averages
    val_loss = torch.tensor(val_losses).mean().item()
    fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
    mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
    psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
    ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
    lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

    # synchronize metrics across GPUs in DDP mode
    if self.use_ddp:
        metrics = [val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg]
        # inf stands in for "not computed" so every rank reduces the same number of tensors
        metrics_tensors = [torch.tensor(m, device=self.device) if m is not None else torch.tensor(float('inf'), device=self.device) for m in metrics]
        for tensor in metrics_tensors:
            dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
        # map the placeholder back: val_loss and FID keep inf, the others become None
        val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg = [t.item() if t.item() != float('inf') else (None if i > 1 else float('inf')) for i, t in enumerate(metrics_tensors)]

    # return to training mode
    self.decoder_model.train()
    if not decoder.forward_diffusion.variance_scheduler.trainable_beta:
        decoder.forward_diffusion.variance_scheduler.eval()
        # reverse diffusion is optional, so only touch it when it exists
        if decoder.reverse_diffusion is not None:
            decoder.reverse_diffusion.variance_scheduler.eval()
    if use_projections:
        if self.finetune_clip_projections:
            self.clip_text_projection.train()
            self.clip_image_projection.train()
        else:
            self.clip_text_projection.eval()
            self.clip_image_projection.eval()
    if self.clip_model is not None:
        self.clip_model.eval()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
2170
+
2171
+ ###==================================================================================================================###
2172
+
2173
+ class TrainUnCLIPPrior(nn.Module):
2174
+ """Trainer for the UnCLIPTransformerPrior model.
2175
+
2176
+ Handles the training of the UnCLIP prior model to predict clean image embeddings from
2177
+ noisy image embeddings and text embeddings, with support for dimension reduction,
2178
+ mixed precision training, and distributed training.
2179
+
2180
+ Parameters
2181
+ ----------
2182
+ `prior_model` : nn.Module
2183
+ The UnCLIP prior model to be trained (e.g., UnCLIPTransformerPrior).
2184
+ `clip_model` : nn.Module
2185
+ CLIP model for encoding text and images.
2186
+ `train_loader` : torch.utils.data.DataLoader
2187
+ DataLoader for training data.
2188
+ `optimizer` : torch.optim.Optimizer
2189
+ Optimizer for training the prior model.
2190
+ `objective` : Callable
2191
+ Loss function to compute the difference between predicted and target embeddings.
2192
+ `val_loader` : torch.utils.data.DataLoader, optional
2193
+ DataLoader for validation data, default None.
2194
+ `max_epochs` : int, optional
2195
+ Maximum number of training epochs (default: 1000).
2196
+ `device` : Union[str, torch.device], optional
2197
+ Device for computation (default: CUDA if available, else CPU).
2198
+ `store_path` : str, optional
2199
+ Directory path to save model checkpoints, default None.
2200
+ `patience` : int, optional
2201
+ Number of epochs to wait for improvement before early stopping (default: 100).
2202
+ `warmup_epochs` : int, optional
2203
+ Number of epochs for learning rate warmup (default: 100).
2204
+ `val_frequency` : int, optional
2205
+ Frequency (in epochs) for validation (default: 10).
2206
+ `use_ddp` : bool, optional
2207
+ Whether to use Distributed Data Parallel training (default: False).
2208
+ `num_grad_accumulation` : int, optional
2209
+ Number of gradient accumulation steps before optimizer update (default: 1).
2210
+ `log_frequency` : int, optional
2211
+ Frequency (in epochs) for printing training progress (default: 1).
2212
+ `use_compilation` : bool, optional
2213
+ Whether to compile models for optimization (default: False).
2214
+ `embedding_output_range` : Tuple[float, float], optional
2215
+ Range for clamping output embeddings (default: (-1.0, 1.0)).
2216
+ `reduce_clip_embedding_dim` : bool, optional
2217
+ Whether to apply dimension reduction to embeddings (default: True).
2218
+ `transformer_embedding_dim` : int, optional
2219
+ Target dimensionality for reduced embeddings (default: 319).
2220
+ `normalize` : bool, optional
2221
+ Whether to normalize CLIP embeddings (default: True).
2222
+ """
2223
+
2224
+ def __init__(
2225
+ self,
2226
+ prior_model: nn.Module,
2227
+ clip_model: nn.Module,
2228
+ train_loader: torch.utils.data.DataLoader,
2229
+ optimizer: torch.optim.Optimizer,
2230
+ objective: Callable,
2231
+ val_loader: Optional[torch.utils.data.DataLoader] = None,
2232
+ max_epochs: int = 1000,
2233
+ device: Optional[Union[str, torch.device]] = None,
2234
+ store_path: Optional[str] = None,
2235
+ patience: int = 100,
2236
+ warmup_epochs: int = 100,
2237
+ val_frequency: int = 10,
2238
+ use_ddp: bool = False,
2239
+ grad_accumulation_steps: int = 1,
2240
+ log_frequency: int = 1,
2241
+ use_compilation: bool = False,
2242
+ embedding_output_range: Tuple[float, float] = (-1.0, 1.0),
2243
+ reduce_clip_embedding_dim: bool = True,
2244
+ transformer_embedding_dim: int = 319,
2245
+ normalize_clip_embeddings: bool = True
2246
+ ) -> None:
2247
+ super().__init__()
2248
+
2249
+ # training configuration
2250
+ self.use_ddp = use_ddp
2251
+ self.grad_accumulation_steps = grad_accumulation_steps
2252
+ if device is None:
2253
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2254
+ elif isinstance(device, str):
2255
+ self.device = torch.device(device)
2256
+ else:
2257
+ self.device = device
2258
+
2259
+ # setup distributed training
2260
+ if self.use_ddp:
2261
+ self._setup_ddp()
2262
+ else:
2263
+ self._setup_single_gpu()
2264
+
2265
+ # core models
2266
+ self.prior_model = prior_model.to(self.device)
2267
+ self.clip_model = clip_model.to(self.device)
2268
+
2269
+ # training components
2270
+ self.optimizer = optimizer
2271
+ self.objective = objective
2272
+ self.train_loader = train_loader
2273
+ self.val_loader = val_loader
2274
+
2275
+ # training parameters
2276
+ self.max_epochs = max_epochs
2277
+ self.patience = patience
2278
+ self.val_frequency = val_frequency
2279
+ self.log_frequency = log_frequency
2280
+ self.use_compilation = use_compilation
2281
+ self.embedding_output_range = embedding_output_range
2282
+ self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
2283
+ self.normalize_clip_embeddings = normalize_clip_embeddings
2284
+ self.transformer_embedding_dim = transformer_embedding_dim
2285
+
2286
+ # checkpoint management
2287
+ self.store_path = store_path
2288
+
2289
+ # learning rate scheduling
2290
+ self.scheduler = ReduceLROnPlateau(
2291
+ self.optimizer,
2292
+ patience=self.patience,
2293
+ factor=0.5
2294
+ )
2295
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
2296
+
2297
+
2298
+ def _setup_ddp(self) -> None:
2299
+ """Sets up Distributed Data Parallel training configuration.
2300
+
2301
+ Initializes the process group, sets up rank information, and configures the CUDA
2302
+ device for the current process.
2303
+
2304
+ Raises
2305
+ ------
2306
+ ValueError
2307
+ If required DDP environment variables (RANK, LOCAL_RANK, WORLD_SIZE) are not set.
2308
+ RuntimeError
2309
+ If CUDA is not available when DDP is enabled.
2310
+ """
2311
+
2312
+ required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
2313
+ for var in required_env_vars:
2314
+ if var not in os.environ:
2315
+ raise ValueError(f"DDP enabled but {var} environment variable not set")
2316
+
2317
+ # ensure CUDA is available for DDP
2318
+ if not torch.cuda.is_available():
2319
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
2320
+
2321
+ # initialize process group only if not already initialized
2322
+ if not torch.distributed.is_initialized():
2323
+ init_process_group(backend="nccl")
2324
+
2325
+ # get rank information
2326
+ self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
2327
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
2328
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
2329
+
2330
+ # set device and make it current
2331
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
2332
+ torch.cuda.set_device(self.device)
2333
+
2334
+ # master process handles logging, checkpointing, etc.
2335
+ self.master_process = self.ddp_rank == 0
2336
+
2337
+ if self.master_process:
2338
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
2339
+
2340
+
2341
+ def _setup_single_gpu(self) -> None:
2342
+ """Sets up single GPU or CPU training configuration.
2343
+
2344
+ Configures the training setup for single-device operation, setting rank and process
2345
+ information for non-DDP training.
2346
+ """
2347
+ self.ddp_rank = 0
2348
+ self.ddp_local_rank = 0
2349
+ self.ddp_world_size = 1
2350
+ self.master_process = True
2351
+
2352
+ @staticmethod
2353
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
2354
+ """Creates a learning rate scheduler for warmup.
2355
+
2356
+ Generates a scheduler that linearly increases the learning rate from 0 to the
2357
+ optimizer's initial value over the specified warmup epochs.
2358
+
2359
+ Parameters
2360
+ ----------
2361
+ `optimizer` : torch.optim.Optimizer
2362
+ Optimizer to apply the scheduler to.
2363
+ `warmup_epochs` : int
2364
+ Number of epochs for the warmup phase.
2365
+
2366
+ Returns
2367
+ -------
2368
+ lr_scheduler : torch.optim.lr_scheduler.LambdaLR
2369
+ Learning rate scheduler for warmup.
2370
+ """
2371
+ def lr_lambda(epoch):
2372
+ return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
2373
+ return LambdaLR(optimizer, lr_lambda)
2374
+
2375
+ def _wrap_models_for_ddp(self) -> None:
2376
+ """Wraps the prior model with DistributedDataParallel for multi-GPU training.
2377
+
2378
+ Configures the prior model for DDP, setting device IDs and handling unused parameters.
2379
+ """
2380
+ if self.use_ddp:
2381
+ # wrap prior with DDP
2382
+ self.prior_model = DDP(
2383
+ self.prior_model,
2384
+ device_ids=[self.ddp_local_rank],
2385
+ find_unused_parameters=True
2386
+ )
2387
+
2388
+ def _compile_models(self) -> None:
2389
+ """Compiles models for optimization if supported.
2390
+
2391
+ Attempts to compile the prior model using torch.compile for performance optimization,
2392
+ with fallback to uncompiled models if compilation fails.
2393
+ """
2394
+ if self.use_compilation:
2395
+ try:
2396
+ self.prior_model = torch.compile(self.prior_model)
2397
+
2398
+ if self.master_process:
2399
+ print("Models compiled successfully")
2400
+ except Exception as e:
2401
+ if self.master_process:
2402
+ print(f"Model compilation failed: {e}. Continuing without compilation.")
2403
+
2404
+ def forward(self) -> Tuple[List[float], float]:
2405
+ """Trains the UnCLIP prior model.
2406
+
2407
+ Executes the training loop, optimizing the prior model to predict clean image embeddings
2408
+ from noisy embeddings and text conditions, with support for validation, early stopping,
2409
+ and checkpointing.
2410
+
2411
+ Returns
2412
+ -------
2413
+ train_losses : List[float]
2414
+ List of mean training losses per epoch.
2415
+ best_val_loss : float
2416
+ Best validation or training loss achieved.
2417
+ """
2418
+ # set models to training mode
2419
+ self.prior_model.train()
2420
+
2421
+ # compile and wrap models
2422
+ self._compile_models()
2423
+ self._wrap_models_for_ddp()
2424
+
2425
+ # initialize training components
2426
+ scaler = torch.GradScaler()
2427
+ train_losses = []
2428
+ best_val_loss = float("inf")
2429
+ wait = 0
2430
+
2431
+ # main training loop
2432
+ for epoch in range(self.max_epochs):
2433
+ # set epoch for distributed sampler if using DDP
2434
+ if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
2435
+ self.train_loader.sampler.set_epoch(epoch)
2436
+
2437
+ train_losses_epoch = []
2438
+
2439
+ # training step loop with gradient accumulation
2440
+ for step, (x, y) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
2441
+ x = x.to(self.device, non_blocking=True)
2442
+
2443
+ # forward pass with mixed precision
2444
+ with torch.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
2445
+ loss = self._compute_training_loss(x, y)
2446
+ loss = loss / self.grad_accumulation_steps
2447
+
2448
+ # backward pass
2449
+ scaler.scale(loss).backward()
2450
+
2451
+ # optimizer step with gradient accumulation
2452
+ if (step + 1) % self.grad_accumulation_steps == 0:
2453
+ self._optimizer_step(scaler)
2454
+ # update learning rate (warmup scheduler)
2455
+ self.warmup_lr_scheduler.step()
2456
+
2457
+ # record loss (unscaled)
2458
+ train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)
2459
+
2460
+ # compute and sync training loss
2461
+ mean_train_loss = self._compute_mean_loss(train_losses_epoch)
2462
+ train_losses.append(mean_train_loss)
2463
+
2464
+ # print training progress (only master process)
2465
+ if self.master_process and (epoch + 1) % self.log_frequency == 0:
2466
+ current_lr = self.optimizer.param_groups[0]['lr']
2467
+ print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}", end="")
2468
+
2469
+ # validation and checkpointing
2470
+ current_loss = mean_train_loss
2471
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
2472
+ val_loss = self.validate()
2473
+ current_loss = val_loss
2474
+
2475
+ if self.master_process:
2476
+ print(f" | Val Loss: {val_loss:.4f}")
2477
+ elif self.master_process:
2478
+ print()
2479
+
2480
+ # learning rate scheduling
2481
+ self.scheduler.step(current_loss)
2482
+
2483
+ # save checkpoint and early stopping
2484
+ if self.master_process:
2485
+ if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
2486
+ best_val_loss = current_loss
2487
+ wait = 0
2488
+ self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
2489
+ else:
2490
+ wait += 1
2491
+ if wait >= self.patience:
2492
+ print("Early stopping triggered")
2493
+ self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
2494
+ break
2495
+
2496
+ # cleanup
2497
+ if self.use_ddp:
2498
+ destroy_process_group()
2499
+
2500
+ return train_losses, best_val_loss
2501
+
2502
+
2503
+ def _compute_training_loss(self, images: torch.Tensor, texts: List[str]) -> torch.Tensor:
2504
+ """Computes the training loss for the UnCLIP prior model.
2505
+
2506
+ Calculates the loss by encoding images and text with CLIP, applying forward diffusion,
2507
+ predicting clean embeddings, and comparing with target embeddings.
2508
+
2509
+ Parameters
2510
+ ----------
2511
+ `images` : torch.Tensor
2512
+ Input images, shape (batch_size, channels, height, width).
2513
+ `texts` : List[str]
2514
+ List of text prompts for conditioning.
2515
+
2516
+ Returns
2517
+ -------
2518
+ loss : torch.Tensor
2519
+ Loss value computed between predicted and target embeddings.
2520
+ """
2521
+
2522
+ with torch.no_grad():
2523
+ # encode text and image with CLIP
2524
+ text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
2525
+ image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
2526
+
2527
+ # reduce dimensionality (optional)
2528
+ if self.reduce_clip_embedding_dim:
2529
+ text_embeddings = self.prior_model.clip_text_projection(text_embeddings)
2530
+ image_embeddings = self.prior_model.clip_image_projection(image_embeddings)
2531
+
2532
+ # sample timestep t ~ Uniform(1, T)
2533
+ batch_size = image_embeddings.shape[0]
2534
+ timesteps = torch.randint(0, self.prior_model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
2535
+
2536
+ # sample noise ε ~ N(0, I)
2537
+ noise = torch.randn_like(image_embeddings)
2538
+
2539
+ # compute noised embedding z_{i,t}
2540
+ noisy_image_embeddings = self.prior_model.forward_diffusion(image_embeddings, noise, timesteps)
2541
+
2542
+ # Predict unnoised embedding ẑ_i
2543
+ predicted_image_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)
2544
+
2545
+ # transform back to original space if using dimension reduction
2546
+ if self.reduce_clip_embedding_dim:
2547
+ predicted_image_embeddings = self.prior_model.clip_image_projection.inverse_transform(predicted_image_embeddings)
2548
+ target_embeddings = self.prior_model.clip_image_projection.inverse_transform(image_embeddings)
2549
+ else:
2550
+ target_embeddings = image_embeddings
2551
+
2552
+ # compute loss L = ||ẑ_i - z_i||²
2553
+ loss = self.objective(predicted_image_embeddings, target_embeddings)
2554
+ return loss
2555
+
2556
+ def _optimizer_step(self, scaler: torch.GradScaler) -> None:
2557
+ """Performs an optimizer step with gradient clipping.
2558
+
2559
+ Applies gradient clipping, updates the optimizer with scaled gradients, and resets
2560
+ gradients for the next iteration.
2561
+
2562
+ Parameters
2563
+ ----------
2564
+ `scaler` : torch.GradScaler
2565
+ Gradient scaler for mixed precision training.
2566
+ """
2567
+ scaler.unscale_(self.optimizer)
2568
+
2569
+ # gradient clipping
2570
+ torch.nn.utils.clip_grad_norm_(self.prior_model.parameters(), max_norm=1.0)
2571
+
2572
+ scaler.step(self.optimizer)
2573
+ scaler.update()
2574
+ self.optimizer.zero_grad()
2575
+
2576
+ def _compute_mean_loss(self, losses: List[float]) -> float:
2577
+ """Computes the mean loss and synchronizes across processes if using DDP.
2578
+
2579
+ Calculates the mean of the provided loss values and performs an all-reduce operation
2580
+ in DDP mode to synchronize the loss across processes.
2581
+
2582
+ Parameters
2583
+ ----------
2584
+ `losses` : List[float]
2585
+ List of loss values from a training or validation epoch.
2586
+
2587
+ Returns
2588
+ -------
2589
+ mean_loss : float
2590
+ Mean loss value, synchronized across processes if DDP is enabled.
2591
+ """
2592
+ mean_loss = torch.tensor(losses).mean().item()
2593
+
2594
+ if self.use_ddp:
2595
+ loss_tensor = torch.tensor(mean_loss, device=self.device)
2596
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
2597
+ mean_loss = loss_tensor.item()
2598
+
2599
+ return mean_loss
2600
+
2601
+
2602
+ def validate(self) -> float:
2603
+ """Validates the UnCLIP prior model.
2604
+
2605
+ Computes the validation loss by encoding images and text, applying forward diffusion,
2606
+ predicting clean embeddings, and comparing with target embeddings.
2607
+
2608
+ Returns
2609
+ -------
2610
+ val_loss : float
2611
+ Mean validation loss, synchronized across processes if DDP is enabled.
2612
+ """
2613
+
2614
+ self.prior_model.eval()
2615
+
2616
+ val_losses = []
2617
+
2618
+ with torch.no_grad():
2619
+ for images, texts in self.val_loader:
2620
+ images = images.to(self.device, non_blocking=True)
2621
+
2622
+ # get embeddings
2623
+ text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
2624
+ image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
2625
+ original_image_embeddings = image_embeddings.clone()
2626
+
2627
+ if self.reduce_clip_embedding_dim:
2628
+ text_embeddings = self.prior_model.clip_text_projection(text_embeddings)
2629
+ image_embeddings = self.prior_model.clip_image_projection(image_embeddings)
2630
+
2631
+ # forward diffusion
2632
+ batch_size = image_embeddings.shape[0]
2633
+ timesteps = torch.randint(0, self.prior_model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
2634
+ noise = torch.randn_like(image_embeddings)
2635
+ noisy_image_embeddings = self.prior_model.forward_diffusion(image_embeddings, noise, timesteps)
2636
+
2637
+ # predict
2638
+ predicted_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)
2639
+
2640
+ if self.reduce_clip_embedding_dim:
2641
+ predicted_embeddings = self.prior_model.clip_image_projection.inverse_transform(predicted_embeddings)
2642
+
2643
+ # compute loss
2644
+ loss = self.objective(predicted_embeddings, original_image_embeddings)
2645
+ val_losses.append(loss.item())
2646
+
2647
+
2648
+ # compute averages
2649
+ val_loss = self._compute_mean_loss(val_losses)
2650
+
2651
+ # return to training mode
2652
+ self.prior_model.train()
2653
+
2654
+ return val_loss
2655
+
2656
+
2657
+ def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "", is_best: bool = False) -> None:
2658
+ """Saves a model checkpoint.
2659
+
2660
+ Saves the state of the prior model and optimizer to a checkpoint file, with options
2661
+ for best model or early stopping checkpoints.
2662
+
2663
+ Parameters
2664
+ ----------
2665
+ `epoch` : int
2666
+ Current epoch number.
2667
+ `loss` : float
2668
+ Current loss value.
2669
+ `suffix` : str, optional
2670
+ Suffix to append to the checkpoint filename, default "".
2671
+ `is_best` : bool, optional
2672
+ Whether to save the checkpoint as the best model, default False.
2673
+ """
2674
+ try:
2675
+ # Get state dicts
2676
+ prior_state = (
2677
+ self.prior_model.module.state_dict() if self.use_ddp
2678
+ else self.prior_model.state_dict()
2679
+ )
2680
+
2681
+ checkpoint = {
2682
+ 'epoch': epoch,
2683
+ 'prior_model_state_dict': prior_state,
2684
+ 'optimizer_state_dict': self.optimizer.state_dict(),
2685
+ 'loss': loss,
2686
+ 'max_epochs': self.max_epochs,
2687
+ }
2688
+
2689
+ # create the directory if it doesn't exist
2690
+ os.makedirs(self.store_path, exist_ok=True)
2691
+
2692
+ # define the checkpoint filename
2693
+ if is_best:
2694
+ filename = "best_model.pth"
2695
+ else:
2696
+ filename = f"checkpoint_epoch_{epoch}{suffix}.pth"
2697
+
2698
+ # construct the full save path
2699
+ save_path = os.path.join(self.store_path, filename)
2700
+
2701
+ # save checkpoint
2702
+ torch.save(checkpoint, save_path)
2703
+ if self.master_process: # only print from the master process in DDP
2704
+ print(f"Checkpoint saved: {save_path}")
2705
+
2706
+ except Exception as e:
2707
+ print(f"Failed to save checkpoint: {e}")
2708
+
2709
+ def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
2710
+ """Loads a model checkpoint to resume training.
2711
+
2712
+ Restores the prior model and optimizer states from a saved checkpoint, handling
2713
+ DDP compatibility for state dictionaries.
2714
+
2715
+ Parameters
2716
+ ----------
2717
+ `checkpoint_path` : str
2718
+ Path to the checkpoint file.
2719
+
2720
+ Returns
2721
+ -------
2722
+ epoch : int
2723
+ The epoch at which the checkpoint was saved.
2724
+ loss : float
2725
+ The loss value at the checkpoint.
2726
+ """
2727
+ try:
2728
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
2729
+ except FileNotFoundError:
2730
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
2731
+
2732
+ # load prior model
2733
+ if 'prior_model_state_dict' in checkpoint:
2734
+ state_dict = checkpoint['prior_model_state_dict']
2735
+
2736
+ # handle DDP state dict compatibility
2737
+ if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
2738
+ state_dict = {f'module.{k}': v for k, v in state_dict.items()}
2739
+ elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
2740
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
2741
+
2742
+ self.prior_model.load_state_dict(state_dict)
2743
+
2744
+ # load optimizer
2745
+ if 'optimizer_state_dict' in checkpoint:
2746
+ try:
2747
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
2748
+ except Exception as e:
2749
+ warnings.warn(f"Failed to load optimizer state: {e}")
2750
+
2751
+ epoch = checkpoint.get('epoch', 0)
2752
+ loss = checkpoint.get('loss', float('inf'))
2753
+
2754
+ if self.master_process:
2755
+ print(f"Loaded checkpoint from {checkpoint_path} (epoch {epoch}, loss {loss:.4f})")
2756
+
2757
+ return epoch, loss
2758
+
2759
+ ###==================================================================================================================###
2760
+
2761
+ class SampleUnCLIP(nn.Module):
2762
+ """Generates images using the UnCLIP model pipeline.
2763
+
2764
+ Combines a prior model, decoder model, CLIP model, and upsampler models to generate
2765
+ images from text prompts or noise. Performs diffusion-based sampling with classifier-free
2766
+ guidance in both prior and decoder stages, followed by upsampling to higher resolutions.
2767
+
2768
+ Parameters
2769
+ ----------
2770
+ `prior_model` : nn.Module
2771
+ The UnCLIP prior model for generating image embeddings from text.
2772
+ `decoder_model` : nn.Module
2773
+ The UnCLIP decoder model for generating low-resolution images from embeddings.
2774
+ `clip_model` : nn.Module
2775
+ CLIP model for encoding text prompts into embeddings.
2776
+ `low_res_upsampler` : nn.Module
2777
+ First upsampler model for scaling images from 64x64 to 256x256.
2778
+ `high_res_upsampler` : nn.Module, optional
2779
+ Second upsampler model for scaling images from 256x256 to 1024x1024, default None.
2780
+ `device` : Union[torch.device, str], optional
2781
+ Device for computation (default: CUDA if available, else CPU).
2782
+ `clip_embedding_dim` : int, optional
2783
+ Dimensionality of CLIP embeddings (default: 512).
2784
+ `prior_guidance_scale` : float, optional
2785
+ Classifier-free guidance scale for the prior model (default: 4.0).
2786
+ `decoder_guidance_scale` : float, optional
2787
+ Classifier-free guidance scale for the decoder model (default: 8.0).
2788
+ `batch_size` : int, optional
2789
+ Number of images to generate per batch (default: 1).
2790
+ `normalize_clip_embeddings` : bool, optional
2791
+ Whether to normalize CLIP embeddings (default: True).
2792
+ `prior_dim_reduction` : bool, optional
2793
+ Whether to apply dimensionality reduction in the prior model (default: True).
2794
+ `initial_image_size` : Tuple[int, int, int], optional
2795
+ Size of the initial generated images (default: (3, 64, 64) for RGB 64x64).
2796
+ `use_high_res_upsampler` : bool, optional
2797
+ Whether to use the second upsampler for 1024x1024 output (default: True).
2798
+ `image_output_range` : Tuple[float, float], optional
2799
+ Range for clamping output images (default: (-1.0, 1.0)).
2800
+ """
2801
+ def __init__(
2802
+ self,
2803
+ prior_model: nn.Module,
2804
+ decoder_model: nn.Module,
2805
+ clip_model: nn.Module,
2806
+ low_res_upsampler: nn.Module,
2807
+ high_res_upsampler: Optional[nn.Module] = None,
2808
+ device: Optional[Union[torch.device, str]] = None,
2809
+ clip_embedding_dim: int = 512, # CLIP embedding dimension
2810
+ prior_guidance_scale: float = 4.0,
2811
+ decoder_guidance_scale: float = 8.0,
2812
+ batch_size: int = 1,
2813
+ normalize_clip_embeddings: bool = True,
2814
+ prior_dim_reduction: bool = True,
2815
+ initial_image_size: Tuple[int, int, int] = (3, 64, 64),
2816
+ use_high_res_upsampler: bool = True,
2817
+ image_output_range: Tuple[float, float] = (-1.0, 1.0)
2818
+ ) -> None:
2819
+ super().__init__()
2820
+
2821
+ if device is None:
2822
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2823
+ elif isinstance(device, str):
2824
+ self.device = torch.device(device)
2825
+ else:
2826
+ self.device = device
2827
+
2828
+ self.prior_model = prior_model.to(self.device).eval()
2829
+ self.decoder_model = decoder_model.to(self.device).eval()
2830
+ self.clip_model = clip_model.to(self.device).eval()
2831
+ self.low_res_upsampler = low_res_upsampler.to(self.device).eval()
2832
+ self.high_res_upsampler = high_res_upsampler.to(self.device).eval() if high_res_upsampler else None
2833
+
2834
+ self.prior_guidance_scale = prior_guidance_scale
2835
+ self.decoder_guidance_scale = decoder_guidance_scale
2836
+ self.batch_size = batch_size
2837
+ self.normalize_clip_embeddings = normalize_clip_embeddings
2838
+ self.prior_dim_reduction = prior_dim_reduction
2839
+ self.clip_embedding_dim = clip_embedding_dim
2840
+ self.initial_image_size = initial_image_size
2841
+ self.use_high_res_upsampler = use_high_res_upsampler
2842
+ self.image_output_range = image_output_range
2843
+ self.images_256 = None
2844
+ self.images_1024 = None
2845
+
2846
+ def forward(
2847
+ self,
2848
+ prompts: Optional[Union[str, List]] = None,
2849
+ normalize_output: bool = True,
2850
+ save_images: bool = True,
2851
+ save_path: str = "unclip_generated"
2852
+ ):
2853
+ """Generates images from text prompts or noise using the UnCLIP pipeline.
2854
+
2855
+ Executes the full UnCLIP generation process: prior model generates image embeddings,
2856
+ decoder model generates 64x64 images, first upsampler scales to 256x256, and optional
2857
+ second upsampler scales to 1024x1024. Supports classifier-free guidance and saves
2858
+ generated images if requested.
2859
+
2860
+ Parameters
2861
+ ----------
2862
+ `prompts` : Union[str, List], optional
2863
+ Text prompt(s) for conditional generation, default None (unconditional).
2864
+ `normalize_output` : bool, optional
2865
+ Whether to normalize output images to [0, 1] range (default: True).
2866
+ `save_images` : bool, optional
2867
+ Whether to save generated images to disk (default: True).
2868
+ `save_path` : str, optional
2869
+ Directory to save generated images (default: "unclip_generated").
2870
+
2871
+ Returns
2872
+ -------
2873
+ final_images : torch.Tensor
2874
+ Generated images, shape (batch_size, channels, height, width), either 256x256
2875
+ or 1024x1024 depending on use_high_res_upsampler.
2876
+ """
2877
+ # initialize noise for prior sampling (image embedding space)
2878
+ embedding_noise = torch.randn((self.batch_size, self.clip_embedding_dim), device=self.device)
2879
+
2880
+ with torch.no_grad():
2881
+
2882
+ # ====== PRIOR STAGE: generate image embeddings from text ======
2883
+ # encode text prompt using CLIP
2884
+ text_embeddings = self.clip_model(data=prompts, data_type="text", normalize=self.normalize_clip_embeddings)
2885
+ current_embeddings = embedding_noise.clone()
2886
+
2887
+ # optionally reduce dimensionality for prior model
2888
+ if self.prior_dim_reduction:
2889
+ text_embeddings_reduced = self.prior_model.clip_text_projection(text_embeddings)
2890
+ current_embeddings_reduced = self.prior_model.clip_image_projection(current_embeddings)
2891
+ else:
2892
+ text_embeddings_reduced = text_embeddings
2893
+ current_embeddings_reduced = current_embeddings
2894
+
2895
+ # prior diffusion sampling loop
2896
+ for t in reversed(range(self.prior_model.forward_diffusion.variance_scheduler.tau_num_steps)):
2897
+ timesteps = torch.full((self.batch_size,), t, device=self.device)
2898
+ prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
2899
+
2900
+ # predict embeddings
2901
+ predicted_embeddings = self.prior_model(text_embeddings_reduced, current_embeddings_reduced, timesteps)
2902
+
2903
+ # apply guidance
2904
+ guided_embeddings = self.compute_prior_guided_prediction(
2905
+ predicted_embeddings, text_embeddings_reduced, current_embeddings_reduced, timesteps
2906
+ )
2907
+
2908
+ # update embeddings using reverse diffusion
2909
+ current_embeddings_reduced, _ = self.prior_model.reverse_diffusion(
2910
+ current_embeddings_reduced, guided_embeddings, timesteps, prev_timesteps
2911
+ )
2912
+
2913
+ # convert back to full embedding dimension if needed
2914
+ if self.prior_dim_reduction:
2915
+ final_image_embeddings = self.prior_model.clip_image_projection.inverse_transform(current_embeddings_reduced)
2916
+ else:
2917
+ final_image_embeddings = current_embeddings_reduced
2918
+
2919
+ # ====== DECODER STAGE: generate 64x64 images from embeddings ======
2920
+ # initialize noise for decoder sampling
2921
+ decoder_noise = torch.randn((self.batch_size, self.initial_image_size[0], self.initial_image_size[1], self.initial_image_size[2]), device=self.device)
2922
+
2923
+ # project image embeddings to 4 tokens
2924
+ projected_embeddings = self.decoder_model.clip_decoder_projection(final_image_embeddings)
2925
+
2926
+ # encode text with GLIDE/decoder's text encoder
2927
+ glide_text_embeddings = self.decoder_model._encode_text_with_glide(prompts)
2928
+
2929
+ # concatenate embeddings for context
2930
+ context = self.decoder_model._concatenate_embeddings(glide_text_embeddings, projected_embeddings)
2931
+
2932
+ current_images = decoder_noise
2933
+
2934
+ for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
2935
+
2936
+ timesteps = torch.full((self.batch_size,), t, device=self.device)
2937
+ prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
2938
+
2939
+ # predict noise
2940
+ predicted_noise = self.decoder_model.noise_predictor(current_images, timesteps, context, None)
2941
+
2942
+ # apply guidance
2943
+ guided_noise = self.compute_decoder_guided_prediction(
2944
+ predicted_noise, current_images, timesteps, context
2945
+ )
2946
+
2947
+ # update images using reverse diffusion
2948
+ current_images, _ = self.decoder_model.reverse_diffusion(
2949
+ current_images, guided_noise, timesteps, prev_timesteps
2950
+ )
2951
+
2952
+ generated_64x64 = current_images
2953
+
2954
+ # ====== FIRST UPSAMPLER: 64x64 -> 256x256 ======
2955
+ upsampled_256_noise = torch.randn((self.batch_size, self.initial_image_size[0], 256, 256), device=self.device)
2956
+ current_256_images = upsampled_256_noise
2957
+
2958
+ for t in reversed(range(self.low_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
2959
+ timesteps = torch.full((self.batch_size,), t, device=self.device)
2960
+ prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
2961
+
2962
+ # predict noise for upsampling (conditioned on low-res image)
2963
+ predicted_noise = self.low_res_upsampler(current_256_images, timesteps, generated_64x64)
2964
+
2965
+ # update using reverse diffusion
2966
+ current_256_images, _ = self.low_res_upsampler.reverse_diffusion(
2967
+ current_256_images, predicted_noise, timesteps, prev_timesteps
2968
+ )
2969
+
2970
+ self.images_256 = current_256_images
2971
+
2972
+ # ====== SECOND UPSAMPLER: 256x256 -> 1024x1024 (if enabled) ======
2973
+ if self.use_high_res_upsampler and self.high_res_upsampler:
2974
+ upsampled_1024_noise = torch.randn((self.batch_size, self.initial_image_size[0], 1024, 1024), device=self.device)
2975
+ current_1024_images = upsampled_1024_noise
2976
+
2977
+ for t in reversed(range(self.high_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
2978
+ timesteps = torch.full((self.batch_size,), t, device=self.device)
2979
+ prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
2980
+
2981
+ # predict noise for upsampling (conditioned on 256x256 image)
2982
+ predicted_noise = self.high_res_upsampler(current_1024_images, timesteps, self.images_256)
2983
+
2984
+ # update using reverse diffusion
2985
+ current_1024_images, _ = self.high_res_upsampler.reverse_diffusion(
2986
+ current_1024_images, predicted_noise, timesteps, prev_timesteps
2987
+ )
2988
+
2989
+ self.images_1024 = current_1024_images
2990
+
2991
+ # ====== POST-PROCESSING ======
2992
+ # normalize output to [0, 1] range if requested
2993
+ if normalize_output:
2994
+ final_256 = (self.images_256 - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
2995
+ final_1024 = None
2996
+ if self.images_1024 is not None:
2997
+ final_1024 = (self.images_1024 - self.image_output_range[0]) / (
2998
+ self.image_output_range[1] - self.image_output_range[0])
2999
+ else:
3000
+ final_256 = self.images_256
3001
+ final_1024 = self.images_1024
3002
+
3003
+ # save images if requested
3004
+ if save_images:
3005
+ os.makedirs(save_path, exist_ok=True)
3006
+ os.makedirs(os.path.join(save_path, "images_256"), exist_ok=True)
3007
+ if final_1024 is not None:
3008
+ os.makedirs(os.path.join(save_path, "images_1024"), exist_ok=True)
3009
+
3010
+ for i in range(self.batch_size):
3011
+ img_path_256 = os.path.join(save_path, "images_256", f"image_{i+1}.png")
3012
+ torchvision.utils.save_image(final_256[i], img_path_256)
3013
+
3014
+ if final_1024 is not None:
3015
+ img_path_1024 = os.path.join(save_path, "images_1024", f"image_{i+1}.png")
3016
+ torchvision.utils.save_image(final_1024[i], img_path_1024)
3017
+
3018
+ # return final images
3019
+ if final_1024 is not None:
3020
+ return final_1024
3021
+ else:
3022
+ return final_256
3023
+
3024
def compute_prior_guided_prediction(
        self,
        predicted_embeddings: torch.Tensor,
        text_embeddings: torch.Tensor,
        current_embeddings: torch.Tensor,
        timesteps: torch.Tensor
) -> torch.Tensor:
    """Applies classifier-free guidance to the prior's embedding prediction.

    The prior is evaluated a second time with an all-zero tensor in place of the
    text embeddings, and the text-conditioned prediction is extrapolated away
    from that unconditioned result using `self.prior_guidance_scale`.

    Parameters
    ----------
    `predicted_embeddings` : torch.Tensor
        Prediction already computed with the real text conditioning.
    `text_embeddings` : torch.Tensor
        Text conditioning; only its shape, dtype and device are used here
        (to build the zero conditioning).
    `current_embeddings` : torch.Tensor
        Current noisy embeddings passed to the prior unchanged.
    `timesteps` : torch.Tensor
        Timestep indices passed to the prior unchanged.

    Returns
    -------
    guided_embeddings : torch.Tensor
        `(1 + scale) * conditioned - scale * unconditioned`.
    """
    # an all-zero conditioning stands in for "no text"
    null_text = torch.zeros_like(text_embeddings)
    unconditioned = self.prior_model(null_text, current_embeddings, timesteps)

    scale = self.prior_guidance_scale
    conditioned_term = (1.0 + scale) * predicted_embeddings
    return conditioned_term - scale * unconditioned
3058
+
3059
def compute_decoder_guided_prediction(
        self,
        predicted_noise: torch.Tensor,
        current_images: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor
) -> torch.Tensor:
    """Applies classifier-free guidance to the decoder's noise prediction.

    The decoder's noise predictor is run once more with the context replaced by
    zeros, and the conditioned prediction is pushed away from that result by
    `self.decoder_guidance_scale`.

    Parameters
    ----------
    `predicted_noise` : torch.Tensor
        Noise prediction already computed with the real context.
    `current_images` : torch.Tensor
        Current noisy images passed to the noise predictor unchanged.
    `timesteps` : torch.Tensor
        Timestep indices passed to the noise predictor unchanged.
    `context` : torch.Tensor
        Conditioning context; only its shape, dtype and device are used here
        (to build the zero context).

    Returns
    -------
    guided_noise : torch.Tensor
        `(1 + scale) * conditioned - scale * unconditioned`.
    """
    # zeroed context plays the role of the unconditional branch
    null_context = torch.zeros_like(context)
    unconditioned = self.decoder_model.noise_predictor(
        current_images, timesteps, null_context, None
    )

    scale = self.decoder_guidance_scale
    conditioned_term = (1.0 + scale) * predicted_noise
    return conditioned_term - scale * unconditioned
3093
+
3094
def to(self, device: Union[torch.device, str]) -> Self:
    """Moves the module and all its components to the specified device.

    Records the target in `self.device` and moves the prior, decoder, CLIP and
    low-resolution upsampler models, plus the optional high-resolution
    upsampler when one is attached.

    Parameters
    ----------
    device : Union[torch.device, str]
        Target device; a string is converted with `torch.device`.

    Returns
    -------
    Self
        The result of `nn.Module.to(device)` for this module.
    """
    if isinstance(device, str):
        device = torch.device(device)

    self.device = device

    # mandatory sub-models
    self.prior_model.to(device)
    self.decoder_model.to(device)
    self.clip_model.to(device)
    self.low_res_upsampler.to(device)

    # The sampling code refers to the optional second stage as
    # `high_res_upsampler`; the previous implementation looked it up as
    # `second_upsampler_model`, which raises AttributeError on an nn.Module
    # when no such attribute exists. Both names are resolved defensively so
    # either spelling keeps working.
    for attr_name in ("high_res_upsampler", "second_upsampler_model"):
        optional_upsampler = getattr(self, attr_name, None)
        if optional_upsampler is not None:
            optional_upsampler.to(device)

    return super().to(device)
3125
+
3126
+ ###==================================================================================================================###
3127
+
3128
class UpsamplerUnCLIP(nn.Module):
    """Diffusion-based upsampler for UnCLIP models.

    A time-conditioned U-Net. The noisy high-resolution image is concatenated
    channel-wise with a bicubically resized copy of the low-resolution image,
    passed through an encoder of residual blocks with strided downsampling, two
    middle residual blocks, and a decoder of residual blocks with transposed
    convolution upsampling that consumes one skip connection per level.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module; only stored here (used by the trainer).
    `reverse_diffusion` : nn.Module
        Reverse diffusion module; only stored here (used at inference time).
    `in_channels` : int, optional
        Channels of each input image (default: 3). The first convolution sees
        twice this many because the two images are concatenated.
    `out_channels` : int, optional
        Channels of the predicted noise (default: 3).
    `model_channels` : int, optional
        Base channel width (default: 192).
    `num_res_blocks` : int, optional
        Residual blocks per encoder level; decoder levels use one more (default: 2).
    `channel_mult` : Tuple[int, ...], optional
        Width multiplier per resolution level (default: (1, 2, 4, 8)).
    `dropout_rate` : float, optional
        Dropout probability handed to every residual block (default: 0.1).
    `time_embed_dim` : int, optional
        Width of the time embedding (default: 768).
    `low_res_size` : int, optional
        Stored for callers; not used by `forward` (default: 64).
    `high_res_size` : int, optional
        Stored for callers; not used by `forward` (default: 256).
    """

    def __init__(
            self,
            forward_diffusion: nn.Module,
            reverse_diffusion: nn.Module,
            in_channels: int = 3,
            out_channels: int = 3,
            model_channels: int = 192,
            num_res_blocks: int = 2,
            channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
            dropout_rate: float = 0.1,
            time_embed_dim: int = 768,
            low_res_size: int = 64,
            high_res_size: int = 256,
    ) -> None:
        super().__init__()

        # diffusion processes are carried along for the trainer / sampler
        self.forward_diffusion = forward_diffusion
        self.reverse_diffusion = reverse_diffusion
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size

        # timestep -> sinusoidal features -> two-layer MLP
        self.time_embed = nn.Sequential(
            SinusoidalPositionalEmbedding(model_channels),
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # stem: noisy high-res image and resized low-res image side by side
        self.input_proj = nn.Conv2d(in_channels * 2, model_channels, 3, padding=1)

        level_widths = [model_channels * mult for mult in channel_mult]
        top_level = len(level_widths) - 1

        # encoder: residual blocks per level, downsample between levels
        self.encoder_blocks = nn.ModuleList()
        self.downsample_blocks = nn.ModuleList()

        ch = model_channels
        for level, width in enumerate(level_widths):
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResBlock(ch, width, time_embed_dim, dropout_rate)
                )
                ch = width
            if level != top_level:
                self.downsample_blocks.append(DownsampleBlock(ch, ch))

        # bottleneck
        self.middle_blocks = nn.ModuleList(
            [ResBlock(ch, ch, time_embed_dim, dropout_rate) for _ in range(2)]
        )

        # decoder: mirror of the encoder, deepest level first
        self.decoder_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()

        for level in range(top_level, -1, -1):
            width = level_widths[level]
            for position in range(num_res_blocks + 1):
                # the first block of a level also receives that level's skip tensor
                in_ch = ch + width if position == 0 else ch
                self.decoder_blocks.append(
                    ResBlock(in_ch, width, time_embed_dim, dropout_rate)
                )
                ch = width
            if level != 0:
                self.upsample_blocks.append(UpsampleBlock(ch, ch))

        # head: normalise, activate, project to the output channels
        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1),
        )

    def forward(self, x_high: torch.Tensor, t: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
        """Predicts noise for the upsampling process.

        Parameters
        ----------
        `x_high` : torch.Tensor
            Noisy high-resolution image; its last two dimensions set the working resolution.
        `t` : torch.Tensor
            Timestep indices; cast to float before embedding.
        `x_low` : torch.Tensor
            Low-resolution conditioning image; resized bicubically to `x_high`'s spatial size.

        Returns
        -------
        out : torch.Tensor
            Output of the final projection (`out_channels` channels).
        """
        # bring the conditioning image to the working resolution
        target_size = (x_high.shape[-2], x_high.shape[-1])
        x_low_upsampled = F.interpolate(
            x_low, size=target_size, mode='bicubic', align_corners=False
        )

        stacked = torch.cat([x_high, x_low_upsampled], dim=1)
        time_emb = self.time_embed(t.float())
        h = self.input_proj(stacked)

        # encoder: remember the activation at the end of every level
        skips = []
        for count, block in enumerate(self.encoder_blocks, start=1):
            h = block(h, time_emb)
            finished_levels, remainder = divmod(count, self.num_res_blocks)
            if remainder:
                continue
            skips.append(h)
            down_index = finished_levels - 1
            if down_index < len(self.downsample_blocks):
                h = self.downsample_blocks[down_index](h)

        # bottleneck
        for block in self.middle_blocks:
            h = block(h, time_emb)

        # decoder: concatenate a skip at the start of each level, upsample at its end
        blocks_per_level = self.num_res_blocks + 1
        next_upsample = 0
        for index, block in enumerate(self.decoder_blocks):
            position = index % blocks_per_level
            if position == 0 and skips:
                h = torch.cat([h, skips.pop()], dim=1)

            h = block(h, time_emb)

            level_done = position == blocks_per_level - 1
            if level_done and next_upsample < len(self.upsample_blocks):
                h = self.upsample_blocks[next_upsample](h)
                next_upsample += 1

        return self.output_proj(h)
3315
+
3316
+
3317
+
3318
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding for timesteps.

    Maps each timestep to `dim // 2` sine features followed by `dim // 2` cosine
    features over geometrically spaced frequencies. When `dim` is odd, a single
    zero column is appended so the output width is always exactly `dim`.

    Parameters
    ----------
    `dim` : int
        Dimensionality of the embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Generates sinusoidal embeddings for timesteps.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Timestep values, shape (batch_size,).

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal embeddings, shape (batch_size, dim).
        """
        device = timesteps.device
        half_dim = self.dim // 2

        # `half_dim - 1` is zero when dim is 2 or 3; clamp the divisor so that case
        # yields a single frequency of 1 instead of a ZeroDivisionError.
        # For half_dim >= 2 the value is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)

        angles = timesteps[:, None] * frequencies[None, :]
        embeddings = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        # odd dim: sin/cos only fill dim - 1 columns, pad to the documented width
        if self.dim % 2 == 1:
            padding = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat([embeddings, padding], dim=-1)

        return embeddings
3354
+
3355
+
3356
class ResBlock(nn.Module):
    """Residual convolution block modulated by a time embedding.

    The input goes through GroupNorm -> SiLU -> 3x3 conv. The time embedding is
    projected and either used as a per-channel scale and shift applied after
    the second GroupNorm (`use_scale_shift_norm=True`) or simply added before
    it. A SiLU -> Dropout -> 3x3 conv tail follows, and the block input is
    added back through a 1x1 conv (or identity when the widths match).

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    `time_embed_dim` : int
        Dimensionality of time embeddings.
    `dropout` : float, optional
        Dropout probability (default: 0.1).
    `use_scale_shift_norm` : bool, optional
        Whether the time embedding acts as scale and shift (default: True).
    """
    def __init__(self, in_channels: int, out_channels: int, time_embed_dim: int,
                 dropout: float = 0.1, use_scale_shift_norm: bool = True):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(8, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

        # scale-shift mode needs two values per channel, additive mode one
        if use_scale_shift_norm:
            emb_width = out_channels * 2
        else:
            emb_width = out_channels
        self.time_emb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, emb_width)
        )

        self.out_norm = nn.GroupNorm(8, out_channels)
        self.out_rest = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )

        # residual path only needs a projection when the width changes
        same_width = in_channels == out_channels
        self.skip_connection = nn.Identity() if same_width else nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """Runs the block on `x` conditioned on `time_emb`.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        `time_emb` : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).

        Returns
        -------
        out : torch.Tensor
            Output tensor, shape (batch_size, out_channels, height, width).
        """
        hidden = self.in_layers(x)

        # broadcast the projected embedding over the spatial dimensions
        emb_out = self.time_emb_proj(time_emb)[:, :, None, None]

        if not self.use_scale_shift_norm:
            hidden = self.out_norm(hidden + emb_out)
        else:
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            hidden = self.out_norm(hidden) * (1 + scale) + shift

        hidden = self.out_rest(hidden)
        return hidden + self.skip_connection(x)
3433
+
3434
+
3435
class UpsampleBlock(nn.Module):
    """Learned 2x spatial upsampling.

    Wraps a single transposed convolution (kernel 4, stride 2, padding 1), which
    doubles the height and width of its input.

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsamples `x` by a factor of two.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels, height*2, width*2).
        """
        upsampled = self.conv(x)
        return upsampled
3466
+
3467
+
3468
class DownsampleBlock(nn.Module):
    """Learned 2x spatial downsampling.

    Wraps a single strided convolution (kernel 3, stride 2, padding 1), which
    halves the height and width of its input (rounding up for odd sizes).

    Parameters
    ----------
    `in_channels` : int
        Number of input channels.
    `out_channels` : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsamples `x` by a factor of two.

        Parameters
        ----------
        `x` : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        out : torch.Tensor
            Downsampled tensor with `out_channels` channels and halved spatial size.
        """
        downsampled = self.conv(x)
        return downsampled
3499
+
3500
+ ###==================================================================================================================###
3501
+
3502
+ class TrainUpsamplerUnCLIP(nn.Module):
3503
+ """Trainer for the UnCLIP upsampler model.
3504
+
3505
+ Orchestrates the training of the UnCLIP upsampler model, integrating forward diffusion,
3506
+ noise prediction, and low-resolution image conditioning with optional corruption (Gaussian
3507
+ blur or BSR degradation). Supports mixed precision, gradient accumulation, DDP, and
3508
+ comprehensive training utilities.
3509
+
3510
+ Parameters
3511
+ ----------
3512
+ `upsampler_model` : nn.Module
3513
+ The UnCLIP upsampler model (e.g., UpsamplerUnCLIP) to be trained.
3514
+ `train_loader` : torch.utils.data.DataLoader
3515
+ DataLoader for training data, providing low- and high-resolution image pairs.
3516
+ `optimizer` : torch.optim.Optimizer
3517
+ Optimizer for training the upsampler model.
3518
+ `objective` : Callable
3519
+ Loss function to compute the difference between predicted and target noise.
3520
+ `val_loader` : torch.utils.data.DataLoader, optional
3521
+ DataLoader for validation data, default None.
3522
+ `max_epochs` : int, optional
3523
+ Maximum number of training epochs (default: 1000).
3524
+ `device` : Union[str, torch.device], optional
3525
+ Device for computation (default: CUDA if available, else CPU).
3526
+ `store_path` : str, optional
3527
+ Directory to save model checkpoints (default: "unclip_upsampler").
3528
+ `patience` : int, optional
3529
+ Number of epochs to wait for improvement before early stopping (default: 100).
3530
+ `warmup_epochs` : int, optional
3531
+ Number of epochs for learning rate warmup (default: 100).
3532
+ `val_frequency` : int, optional
3533
+ Frequency (in epochs) for validation (default: 10).
3534
+ `use_ddp` : bool, optional
3535
+ Whether to use Distributed Data Parallel training (default: False).
3536
+ `grad_accumulation_steps` : int, optional
3537
+ Number of gradient accumulation steps before optimizer update (default: 1).
3538
+ `log_frequency` : int, optional
3539
+ Frequency (in epochs) for printing progress (default: 1).
3540
+ `use_compilation` : bool, optional
3541
+ Whether to compile the model using torch.compile (default: False).
3542
+ `image_output_range` : Tuple[float, float], optional
3543
+ Range for clamping output images (default: (-1.0, 1.0)).
3544
+ `normalize_image_outputs` : bool, optional
3545
+ Whether to normalize inputs/outputs (default: True).
3546
+ `use_autocast` : bool, optional
3547
+ Whether to use automatic mixed precision training (default: True).
3548
+ """
3549
+
3550
def __init__(
        self,
        upsampler_model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        store_path: str = "unclip_upsampler",
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_image_outputs: bool = True,
        use_autocast: bool = True
) -> None:
    """Stores the training configuration and builds the LR schedulers.

    Resolves the compute device, runs the DDP or single-GPU setup helper, calls
    the compile / DDP-wrap helpers, moves the upsampler to the device, and
    creates a `ReduceLROnPlateau` scheduler plus the warmup scheduler.
    See the class docstring for the meaning of each argument.
    """
    super().__init__()

    # flags read by the setup helpers and by the training loop
    self.use_autocast = use_autocast
    self.use_compilation = use_compilation
    self.grad_accumulation_steps = grad_accumulation_steps
    self.use_ddp = use_ddp

    # resolve the compute device: explicit string, explicit device, or auto-detect
    if isinstance(device, str):
        self.device = torch.device(device)
    elif device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        self.device = device

    # process-group / rank bookkeeping
    if not self.use_ddp:
        self._setup_single_gpu()
    else:
        self._setup_ddp()

    # NOTE(review): these helpers run before `self.upsampler_model` is assigned
    # below — confirm they do not need the model yet, or move the assignment up.
    self._compile_models()
    self._wrap_models_for_ddp()

    # model under training and the number of diffusion steps to sample from
    self.upsampler_model = upsampler_model.to(self.device)
    scheduler_module = self.upsampler_model.forward_diffusion.variance_scheduler
    self.num_timesteps = scheduler_module.num_steps

    # data, loss and optimiser
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.objective = objective
    self.optimizer = optimizer

    # loop control and output handling
    self.max_epochs = max_epochs
    self.patience = patience
    self.val_frequency = val_frequency
    self.log_frequency = log_frequency
    self.image_output_range = image_output_range
    self.normalize_image_outputs = normalize_image_outputs

    # where checkpoints go
    self.store_path = store_path

    # plateau scheduler halves the LR; warmup scheduler covers the first epochs
    self.scheduler = ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
3625
+
3626
def forward(self) -> Tuple[List[float], float]:
    """Trains the UnCLIP upsampler model to predict noise for denoising.

    For every batch of (low-resolution, high-resolution) images the loop samples
    timesteps and noise, noises the high-resolution images with the forward
    diffusion, corrupts the low-resolution conditioning image, and optimises the
    objective between predicted and true noise. Gradient accumulation, gradient
    clipping, optional mixed precision, LR warmup, plateau scheduling,
    validation, checkpointing and early stopping are handled here.

    The forward diffusion is always excluded from autocast on the device that is
    actually in use (previously the exclusion was hard-coded to CUDA, so it had
    no effect when training under CPU autocast).

    Returns
    -------
    train_losses : List[float]
        List of mean training losses per epoch.
    best_val_loss : float
        Best loss recorded on an epoch where `(epoch + 1) % val_frequency == 0`.
    """
    # local import: the file's import block is outside this method's view
    from contextlib import nullcontext

    # set models to training mode
    self.upsampler_model.train()
    variance_scheduler = self.upsampler_model.forward_diffusion.variance_scheduler
    if variance_scheduler.trainable_beta:
        variance_scheduler.train()
    else:
        variance_scheduler.eval()

    # initialize training components
    scaler = torch.GradScaler() if self.use_autocast else None
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # one device type for both the outer autocast and the inner FP32 guard
    autocast_device = 'cuda' if self.device.type == 'cuda' else 'cpu'

    # main training loop
    for epoch in range(self.max_epochs):
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (low_res_images, high_res_images) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
            low_res_images = low_res_images.to(self.device, non_blocking=True)
            high_res_images = high_res_images.to(self.device, non_blocking=True)

            if self.use_autocast:
                mixed_precision = torch.autocast(device_type=autocast_device)
                # force FP32 for forward_diffusion to avoid NaN in variance scheduling
                full_precision = torch.autocast(device_type=autocast_device, enabled=False)
            else:
                mixed_precision = nullcontext()
                full_precision = nullcontext()

            # forward pass (identical maths with or without autocast)
            with mixed_precision:
                batch_size = high_res_images.shape[0]
                timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
                noise = torch.randn_like(high_res_images)
                with full_precision:
                    high_res_images_noisy = self.upsampler_model.forward_diffusion(high_res_images, noise, timesteps)
                corruption_type = "gaussian_blur" if self.upsampler_model.low_res_size == 64 else "bsr_degradation"
                low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
                predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
                # scale so that accumulated gradients average over the micro-batches
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # backward pass
            if self.use_autocast:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients (unscale first so the norm is measured on real gradients)
                if self.use_autocast:
                    scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.upsampler_model.parameters(), max_norm=1.0)
                torch.nn.utils.clip_grad_norm_(self.upsampler_model.forward_diffusion.parameters(), max_norm=1.0)

                # optimizer step
                if self.use_autocast:
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()
                torch.cuda.empty_cache()  # clear memory after optimizer step

            # undo the accumulation scaling so the logged value is the raw loss
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        self.warmup_lr_scheduler.step()

        mean_train_loss = self._compute_mean_loss(train_losses_epoch)
        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        current_loss = mean_train_loss

        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss = self.validate()
            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}")
                print()
            current_loss = val_loss

        self.scheduler.step(current_loss)

        # NOTE(review): the early-stop `break` below only runs on the master
        # process — confirm non-master ranks cannot be left waiting under DDP.
        if self.master_process:
            if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                    break

    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
3748
+
3749
def _compute_mean_loss(self, losses: List[float]) -> float:
    """Averages the epoch's batch losses, across ranks when DDP is enabled.

    The local arithmetic mean is computed first. In DDP mode the per-rank means
    are summed with an all-reduce and divided by ``self.ddp_world_size``.

    Parameters
    ----------
    `losses` : List[float]
        Loss values collected during the current epoch.

    Returns
    -------
    mean_loss : float
        Mean loss (0.0 for an empty list), averaged over ranks if using DDP.
    """
    # NOTE(review): the empty-list early return skips the all-reduce, so in DDP
    # every rank is presumably expected to see at least one batch - confirm.
    if not losses:
        return 0.0

    local_mean = sum(losses) / len(losses)
    if not self.use_ddp:
        return local_mean

    # sum the per-rank means, then divide by the number of ranks
    reduced = torch.tensor(local_mean, device=self.device)
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return (reduced / self.ddp_world_size).item()
3775
+
3776
def _setup_ddp(self) -> None:
    """Configures this process for Distributed Data Parallel training.

    Verifies that the ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` environment
    variables exist and that CUDA is available, initializes the NCCL process
    group if none is initialized yet, then records the rank information and
    selects the CUDA device matching the local rank.

    Raises
    ------
    ValueError
        If one of the required environment variables is missing.
    RuntimeError
        If CUDA is not available.
    """
    for env_name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
        if env_name not in os.environ:
            raise ValueError(f"DDP enabled but {env_name} environment variable not set")

    if not torch.cuda.is_available():
        raise RuntimeError("DDP requires CUDA but CUDA is not available")

    # another component may already have created the process group
    if not torch.distributed.is_initialized():
        init_process_group(backend="nccl")

    env = os.environ
    self.ddp_rank = int(env["RANK"])
    self.ddp_local_rank = int(env["LOCAL_RANK"])
    self.ddp_world_size = int(env["WORLD_SIZE"])

    # bind this process to the GPU indexed by its local rank
    self.device = torch.device(f"cuda:{self.ddp_local_rank}")
    torch.cuda.set_device(self.device)

    # only global rank 0 is treated as the master process
    self.master_process = self.ddp_rank == 0
    if self.master_process:
        print(f"DDP initialized with world_size={self.ddp_world_size}")
3804
+
3805
def _setup_single_gpu(self) -> None:
    """Configures rank bookkeeping for non-DDP (single GPU or CPU) training.

    Sets global and local rank to 0, world size to 1 and marks this process
    as the master process.
    """
    self.ddp_rank = self.ddp_local_rank = 0
    self.ddp_world_size = 1
    self.master_process = True
3815
+
3816
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Builds a linear learning-rate warmup scheduler.

    The learning-rate multiplier is ``step / warmup_epochs`` capped at 1.0, so
    the rate ramps from 0 up to the optimizer's configured value. With
    ``warmup_epochs <= 0`` the multiplier is always 1.0 (no warmup).

    Parameters
    ----------
    `optimizer` : torch.optim.Optimizer
        Optimizer whose learning rate is scaled.
    `warmup_epochs` : int
        Number of scheduler steps over which the rate ramps up.

    Returns
    -------
    lr_scheduler : torch.optim.lr_scheduler.LambdaLR
        Scheduler applying the warmup multiplier.
    """
    def warmup_factor(step):
        if warmup_epochs > 0:
            return min(1.0, step / warmup_epochs)
        return 1.0

    return LambdaLR(optimizer, warmup_factor)
3839
+
3840
def _wrap_models_for_ddp(self) -> None:
    """Wraps the upsampler model in DistributedDataParallel when DDP is enabled.

    Does nothing when ``self.use_ddp`` is false. Otherwise the model is moved to
    the device indexed by ``self.ddp_local_rank`` and replaced by its DDP wrapper.
    """
    if not self.use_ddp:
        return

    local_rank = self.ddp_local_rank
    self.upsampler_model = self.upsampler_model.to(local_rank)
    self.upsampler_model = DDP(
        self.upsampler_model,
        device_ids=[local_rank],
        find_unused_parameters=True,
    )
3852
+
3853
def _compile_models(self) -> None:
    """Compiles the upsampler model with ``torch.compile`` when requested.

    Does nothing when ``self.use_compilation`` is false. Compilation is
    best-effort: any exception is reported on the master process and training
    continues with the uncompiled model.
    """
    if not self.use_compilation:
        return

    try:
        self.upsampler_model = self.upsampler_model.to(self.device)
        self.upsampler_model = torch.compile(self.upsampler_model, mode="reduce-overhead")
    except Exception as e:
        # deliberate fallback: a failed compilation must not abort training
        if self.master_process:
            print(f"Model compilation failed: {e}. Continuing without compilation.")
    else:
        if self.master_process:
            print("Models compiled successfully")
3869
+
3870
def corrupt_conditioning_image(self, x_low: torch.Tensor, corruption_type: str = "gaussian_blur") -> torch.Tensor:
    """Degrades the low-resolution conditioning image.

    ``"gaussian_blur"`` blurs the image with a randomly drawn kernel size
    (3, 5 or 7) and sigma (uniform in [0.5, 2.0]). ``"bsr_degradation"``
    delegates to ``_bsr_degradation``. Any other value returns the input
    unchanged.

    Parameters
    ----------
    `x_low` : torch.Tensor
        Low-resolution input image, shape (batch_size, channels, low_res_size, low_res_size).
    `corruption_type` : str, optional
        Type of corruption to apply: "gaussian_blur" or "bsr_degradation" (default: "gaussian_blur").

    Returns
    -------
    x_degraded : torch.Tensor
        Corrupted low-resolution image.
    """
    if corruption_type == "bsr_degradation":
        # more diverse degradation (noise + blur)
        return self._bsr_degradation(x_low)

    if corruption_type != "gaussian_blur":
        # unknown corruption types leave the image untouched
        return x_low

    blur_kernel = random.choice([3, 5, 7])
    blur_sigma = random.uniform(0.5, 2.0)
    return self._gaussian_blur(x_low, blur_kernel, blur_sigma)
3898
+
3899
def _gaussian_blur(self, x: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    """Applies a per-channel (depthwise) Gaussian blur to the input image.

    Parameters
    ----------
    `x` : torch.Tensor
        Input image tensor, shape (batch_size, channels, height, width).
    `kernel_size` : int
        Size of the Gaussian kernel.
    `sigma` : float
        Standard deviation of the Gaussian distribution.

    Returns
    -------
    x_blurred : torch.Tensor
        Blurred image tensor. With zero padding of ``kernel_size // 2`` the
        spatial size equals the input's for odd ``kernel_size``.
    """
    kernel = self._get_gaussian_kernel(kernel_size, sigma)
    # The kernel is built in float32; conv2d requires input and weight to share
    # a dtype, so follow the input's floating dtype (float64, half, bfloat16).
    target_dtype = x.dtype if x.is_floating_point() else kernel.dtype
    kernel = kernel.to(device=x.device, dtype=target_dtype)

    # one identical 2D kernel per channel; groups=channels blurs each channel independently
    channels = x.shape[1]
    kernel = kernel.expand(channels, 1, kernel_size, kernel_size)
    return F.conv2d(x, kernel, padding=kernel_size // 2, groups=channels)
3921
+
3922
def _get_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
    """Builds a normalized 2D Gaussian kernel.

    A 1D Gaussian is sampled at integer offsets centred on ``kernel_size // 2``,
    normalized to sum to one, and combined with itself as an outer product.

    Parameters
    ----------
    `kernel_size` : int
        Size of the Gaussian kernel.
    `sigma` : float
        Standard deviation of the Gaussian distribution.

    Returns
    -------
    kernel : torch.Tensor
        2D Gaussian kernel (float32), shape (kernel_size, kernel_size).
    """
    half_width = kernel_size // 2
    offsets = torch.arange(kernel_size, dtype=torch.float32) - half_width
    weights_1d = torch.exp(-offsets.pow(2) / (2 * sigma ** 2))
    weights_1d = weights_1d / weights_1d.sum()
    # separable kernel: outer product of the normalized 1D profile with itself
    return torch.outer(weights_1d, weights_1d)
3941
+
3942
def _bsr_degradation(self, x: torch.Tensor) -> torch.Tensor:
    """Degrades the input image with additive noise followed by Gaussian blur.

    Gaussian noise with a random scale in [0.0, 0.1] is added, the result is
    blurred with a random kernel size (3, 5 or 7) and sigma in [0.5, 3.0], and
    the output is clamped to [-1, 1].

    Parameters
    ----------
    `x` : torch.Tensor
        Input image tensor, shape (batch_size, channels, height, width).

    Returns
    -------
    x_degraded : torch.Tensor
        Degraded image tensor, clamped to [-1, 1].
    """
    # additive Gaussian noise with a randomly drawn strength
    noise_scale = random.uniform(0.0, 0.1)
    noisy = x + torch.randn_like(x) * noise_scale

    # random blur applied on top of the noisy image
    blur_kernel = random.choice([3, 5, 7])
    blur_sigma = random.uniform(0.5, 3.0)
    blurred = self._gaussian_blur(noisy, blur_kernel, blur_sigma)

    return torch.clamp(blurred, -1.0, 1.0)
3968
+
3969
def validate(self) -> float:
    """Validates the UnCLIP upsampler model.

    For every validation batch, noise is added to the high-resolution images
    through the model's forward diffusion, the low-resolution images are
    corrupted, the upsampler predicts the noise, and ``self.objective``
    compares the prediction with the true noise. The mean batch loss is
    averaged across ranks in DDP mode.

    Returns
    -------
    val_loss : float
        Mean validation loss (NaN if the validation loader yields no batches).
    """
    # A DDP wrapper does not expose the wrapped model's custom attributes
    # (forward_diffusion, low_res_size), so read them from the underlying module.
    core_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model
    forward_diffusion = core_model.forward_diffusion

    # set models to eval mode for evaluation
    self.upsampler_model.eval()
    forward_diffusion.eval()

    # the corruption type depends only on the model configuration, not on the batch
    corruption_type = "gaussian_blur" if core_model.low_res_size == 64 else "bsr_degradation"

    val_losses = []
    with torch.no_grad():
        for low_res_images, high_res_images in self.val_loader:
            low_res_images = low_res_images.to(self.device, non_blocking=True)
            high_res_images = high_res_images.to(self.device, non_blocking=True)

            batch_size = high_res_images.shape[0]
            timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
            noise = torch.randn_like(high_res_images)
            high_res_images_noisy = forward_diffusion(high_res_images, noise, timesteps)

            low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
            predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)

            loss = self.objective(predicted_noise, noise)
            val_losses.append(loss.item())

    # compute average loss
    val_loss = torch.tensor(val_losses).mean().item()

    if self.use_ddp:
        # same reduction as _compute_mean_loss: sum over ranks, divide by world size
        val_loss_tensor = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
        val_loss = (val_loss_tensor / self.ddp_world_size).item()

    # return to training mode; a non-trainable variance scheduler stays in eval mode
    self.upsampler_model.train()
    variance_scheduler = forward_diffusion.variance_scheduler
    if not variance_scheduler.trainable_beta:
        variance_scheduler.eval()

    return val_loss
4016
+
4017
def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
    """Saves model checkpoint.

    Stores the upsampler model weights, its variance scheduler, the optimizer
    and both LR schedulers, plus a few configuration values used for
    compatibility checks on load. Only the master process writes a file.

    Parameters
    ----------
    `epoch` : int
        Current epoch number.
    `loss` : float
        Current loss value.
    `is_best` : bool, optional
        Whether to save as the best model checkpoint (default: False).
    `suffix` : str, optional
        Suffix to add to checkpoint filename, default "".
    """
    if not self.master_process:
        return

    # Under DDP the weights and custom attributes (model_channels, num_res_blocks,
    # forward_diffusion) live on the wrapped module, not on the DDP wrapper.
    core_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    checkpoint = {
        'epoch': epoch,
        'loss': loss,
        # core model
        'upsampler_model_state_dict': core_model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        # training configuration
        'model_channels': core_model.model_channels,
        'num_res_blocks': core_model.num_res_blocks,
        'normalize': self.normalize_image_outputs,
        'output_range': self.image_output_range,
        # variance scheduler (submodule of forward_diffusion)
        'variance_scheduler_state_dict': core_model.forward_diffusion.variance_scheduler.state_dict(),
        # schedulers state
        'scheduler_state_dict': self.scheduler.state_dict(),
        'warmup_scheduler_state_dict': self.warmup_lr_scheduler.state_dict(),
    }

    # the best checkpoint uses a fixed name, so each new best overwrites the previous one
    if is_best:
        filename = f"unclip_upsampler_best{suffix}.pth"
    else:
        filename = f"unclip_upsampler_epoch_{epoch}{suffix}.pth"

    os.makedirs(self.store_path, exist_ok=True)
    filepath = os.path.join(self.store_path, filename)
    torch.save(checkpoint, filepath)

    if is_best:
        print(f"Best model saved: {filepath}")
4069
+
4070
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads model checkpoint.

    Restores the upsampler model, its variance scheduler, the optimizer and
    both LR schedulers from a saved checkpoint. Each component is loaded on a
    best-effort basis: a failure is reported with ``warnings.warn`` and the
    remaining components are still loaded. Configuration mismatches
    (model channels, residual blocks, normalization) are reported as warnings.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if not stored).
    loss : float
        The loss at the checkpoint (infinity if not stored).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    try:
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from err

    ddp_prefix = 'module.'
    # A DDP wrapper does not expose the wrapped model's custom attributes
    # (forward_diffusion, model_channels, ...); read them from the inner module.
    core_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
        """Helper function to load state dict with DDP compatibility."""
        try:
            # Decide from the target's own keys whether it expects the DDP prefix:
            # under DDP only the wrapped model does, its submodules do not.
            saved_prefixed = any(key.startswith(ddp_prefix) for key in state_dict)
            target_prefixed = any(key.startswith(ddp_prefix) for key in model.state_dict())

            if target_prefixed and not saved_prefixed:
                state_dict = {f'{ddp_prefix}{k}': v for k, v in state_dict.items()}
            elif saved_prefixed and not target_prefixed:
                # strip only the leading prefix; names such as "sub_module.weight"
                # contain "module." in the middle and must stay intact
                state_dict = {
                    (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
                    for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            if self.master_process:
                print(f"✓ Loaded {model_name}")
        except Exception as e:
            # best-effort: report the failure and keep loading other components
            warnings.warn(f"Failed to load {model_name}: {e}")

    # load core upsampler model
    if 'upsampler_model_state_dict' in checkpoint:
        _load_model_state_dict(self.upsampler_model, checkpoint['upsampler_model_state_dict'],
                               'upsampler_model')

    # load variance scheduler (submodule of forward_diffusion); the older
    # 'hyper_params_state_dict' key is accepted as a fallback
    if 'variance_scheduler_state_dict' in checkpoint or 'hyper_params_state_dict' in checkpoint:
        state_dict = checkpoint.get('variance_scheduler_state_dict', checkpoint.get('hyper_params_state_dict'))
        try:
            _load_model_state_dict(core_model.forward_diffusion.variance_scheduler, state_dict, 'variance_scheduler')
        except Exception as e:
            warnings.warn(f"Failed to load variance scheduler: {e}")

    # load optimizer and schedulers: (checkpoint key, attribute, success label, failure label)
    stateful_parts = (
        ('optimizer_state_dict', 'optimizer', "optimizer", "optimizer state"),
        ('scheduler_state_dict', 'scheduler', "main scheduler", "scheduler state"),
        ('warmup_scheduler_state_dict', 'warmup_lr_scheduler', "warmup scheduler", "warmup scheduler state"),
    )
    for key, attr, loaded_label, failed_label in stateful_parts:
        if key not in checkpoint:
            continue
        try:
            getattr(self, attr).load_state_dict(checkpoint[key])
            if self.master_process:
                print(f"✓ Loaded {loaded_label}")
        except Exception as e:
            warnings.warn(f"Failed to load {failed_label}: {e}")

    # verify configuration compatibility
    if 'model_channels' in checkpoint:
        if checkpoint['model_channels'] != core_model.model_channels:
            warnings.warn(
                f"Model channels mismatch: checkpoint={checkpoint['model_channels']}, current={core_model.model_channels}")

    if 'num_res_blocks' in checkpoint:
        if checkpoint['num_res_blocks'] != core_model.num_res_blocks:
            warnings.warn(
                f"Num res blocks mismatch: checkpoint={checkpoint['num_res_blocks']}, current={core_model.num_res_blocks}")

    if 'normalize' in checkpoint:
        if checkpoint['normalize'] != self.normalize_image_outputs:
            warnings.warn(
                f"Normalize setting mismatch: checkpoint={checkpoint['normalize']}, current={self.normalize_image_outputs}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

    return epoch, loss