TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/ddim.py ADDED
@@ -0,0 +1,1222 @@
1
+ """
2
+ **Denoising Diffusion Implicit Models (DDIM)**
3
+
4
+ This module provides a complete implementation of DDIM, as described in Song et al.
5
+ (2021, "Denoising Diffusion Implicit Models"). It includes components for forward and
6
+ reverse diffusion processes, hyperparameter management, training, and image sampling.
7
+ Supports both unconditional and conditional generation with text prompts, using a
8
+ subsampled time step schedule for faster sampling compared to DDPM.
9
+
10
+ **Components**
11
+
12
+ - **ForwardDDIM**: Forward diffusion process to add noise.
13
+ - **ReverseDDIM**: Reverse diffusion process to denoise with subsampled steps.
14
+ - **VarianceSchedulerDDIM**: Noise schedule management with subsampled (tau) schedule.
15
+ - **TrainDDIM**: Training loop with mixed precision and scheduling.
16
+ - **SampleDDIM**: Image generation from trained models with subsampled steps.
17
+
18
+ **Notes**
19
+
20
+ - The subsampled time step schedule (tau) enables faster sampling, controlled by the
21
+ `tau_num_steps` parameter in VarianceSchedulerDDIM.
22
+
23
+ **References**:
24
+
25
+ - Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020).
26
+
27
+ -------------------------------------------------------------------------------
28
+ """
29
+
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.distributed as dist
34
+ from torch.nn.parallel import DistributedDataParallel as DDP
35
+ from torch.distributed import init_process_group, destroy_process_group
36
+ from tqdm import tqdm
37
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
38
+ from transformers import BertTokenizer
39
+ import warnings
40
+ from torchvision.utils import save_image
41
+ from typing import Optional, Tuple, Callable, List, Any, Union, Self
42
+ import os
43
+
44
+
45
+ ###==================================================================================================================###
46
+
47
+
48
class ForwardDDIM(nn.Module):
    """Forward diffusion process of DDIM.

    Perturbs clean data by mixing it with Gaussian noise according to the
    cumulative noise schedule held by the variance scheduler, i.e.
    ``xt = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise``.

    Parameters
    ----------
    `variance_scheduler` : object
        Schedule object (e.g. VarianceSchedulerDDIM). This class reads its
        `num_steps`, `trainable_beta`, `sqrt_alpha_cumprod`,
        `sqrt_one_minus_alpha_cumprod` attributes and calls `compute_schedule`.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Parameters
        ----------
        `x0` : torch.Tensor
            Clean data with the batch on the first dimension. Any number of
            trailing dimensions is accepted (images, latents, vectors, ...).
        `noise` : torch.Tensor
            Noise tensor broadcastable against `x0` (normally the same shape).
        `time_steps` : torch.Tensor
            Integer time step indices, one per batch element, each in
            [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Noisy data at the requested time steps.

        Raises
        ------
        ValueError
            If any entry of `time_steps` lies outside the valid range.
        """
        scheduler = self.variance_scheduler
        if not torch.all((time_steps >= 0) & (time_steps < scheduler.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {scheduler.num_steps - 1}")

        if scheduler.trainable_beta:
            # Trainable schedules are recomputed from the current parameters.
            _, _, _, signal_scale, noise_scale = scheduler.compute_schedule(time_steps)
        else:
            # Index on the table's own device so that time steps living on a
            # different device than the registered buffers do not raise.
            table = scheduler.sqrt_alpha_cumprod
            index = time_steps.to(table.device)
            signal_scale = table[index]
            noise_scale = scheduler.sqrt_one_minus_alpha_cumprod[index]

        # One coefficient per sample, broadcast over every trailing dimension
        # of x0. For 4-D input this is the familiar (-1, 1, 1, 1).
        broadcast_shape = (-1,) + (1,) * max(x0.dim() - 1, 0)
        signal_scale = signal_scale.to(x0.device).view(broadcast_shape)
        noise_scale = noise_scale.to(x0.device).view(broadcast_shape)

        return signal_scale * x0 + noise_scale * noise
106
+
107
+
108
+ ###==================================================================================================================###
109
+
110
+
111
class ReverseDDIM(nn.Module):
    """Reverse diffusion process of DDIM.

    Takes one denoising step on the subsampled (tau) schedule: from the noisy
    sample at `time_steps` and the predicted noise it estimates the clean data
    and re-noises it to the level of `prev_time_steps`, following the update
    rule of Song et al. (2021).

    Parameters
    ----------
    `variance_scheduler` : object
        Schedule object (e.g. VarianceSchedulerDDIM). This class reads its
        `tau_num_steps` and `eta` attributes and calls `get_tau_schedule`.
    """

    def __init__(self, variance_scheduler: torch.nn.Module):
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, xt: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor, prev_time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies the reverse diffusion process to the noisy input.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input at `time_steps`, batch on the first dimension.
        `predicted_noise` : torch.Tensor
            Predicted noise, same shape as `xt`.
        `time_steps` : torch.Tensor
            Integer indices into the tau schedule, one per batch element,
            each in [0, variance_scheduler.tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Integer indices into the tau schedule of the step being moved to,
            same range as `time_steps`.

        Returns
        -------
        xt_prev : torch.Tensor
            Sample at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimate of the clean data, same shape as `xt`.

        Raises
        ------
        ValueError
            If `time_steps` or `prev_time_steps` fall outside the tau schedule.
        """
        scheduler = self.variance_scheduler
        tau_steps = scheduler.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < tau_steps)):
            raise ValueError(f"time_steps must be between 0 and {tau_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < tau_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {tau_steps - 1}")

        _, _, _, sqrt_ac_table, sqrt_omac_table = scheduler.get_tau_schedule()

        # One coefficient per sample, broadcast over all trailing dims of xt.
        broadcast_shape = (-1,) + (1,) * max(xt.dim() - 1, 0)

        def _gather(table: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
            # Index on the table's device, then move the result next to xt.
            return table[index.to(table.device)].to(xt.device).view(broadcast_shape)

        sqrt_ac_t = _gather(sqrt_ac_table, time_steps)
        sqrt_omac_t = _gather(sqrt_omac_table, time_steps)
        sqrt_ac_prev = _gather(sqrt_ac_table, prev_time_steps)
        sqrt_omac_prev = _gather(sqrt_omac_table, prev_time_steps)

        eta = scheduler.eta
        x0 = (xt - sqrt_omac_t * predicted_noise) / sqrt_ac_t

        # DDIM stochasticity (Song et al., Eq. 16):
        #   sigma = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
        # The previous expression cancelled down to eta * sqrt(1 - a_prev) / sqrt(a_prev),
        # which is larger than the marginal std sqrt(1 - a_prev) and therefore
        # wiped out the direction term whenever eta was non-zero.
        step_ratio = torch.clamp(1 - (sqrt_ac_t / sqrt_ac_prev) ** 2, min=0.0)
        noise_coeff = eta * (sqrt_omac_prev / torch.clamp(sqrt_omac_t, min=1e-8)) * step_ratio.sqrt()

        # Remaining variance budget is spent along the predicted-noise direction.
        direction_coeff = torch.clamp(sqrt_omac_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = sqrt_ac_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0
175
+
176
+
177
+ ###==================================================================================================================###
178
+
179
class VarianceSchedulerDDIM(nn.Module):
    """Noise schedule for DDIM with flexible beta computation.

    Holds the beta schedule, the quantities derived from it (alphas,
    alpha_cumprod and their square roots) and the indices of a subsampled
    (tau) schedule used for accelerated sampling. The schedule is either
    fixed (stored as buffers) or trainable (stored as an unconstrained
    parameter that is squashed into [beta_start, beta_end] with a sigmoid).

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the DDIM reverse process (default: 0, deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for DDIM sampling (default: 100).
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Raises
    ------
    ValueError
        If the beta range does not satisfy 0 < start < end < 1, if `num_steps`
        or `tau_num_steps` is not positive, or if `beta_method` is unknown.
    """

    # Keeps the logit of the initial schedule finite: several schedules hit
    # the range endpoints exactly (linear: first/last step, constant: every
    # step), where logit would otherwise be -inf / +inf.
    _LOGIT_EPS = 1e-6

    def __init__(
            self,
            eta: Optional[float] = None,
            num_steps: int = 1000,
            tau_num_steps: int = 100,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear"
    ):
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Reparameterize: learn an unconstrained tensor whose sigmoid maps
            # back into [beta_start, beta_end] (see the `betas` property).
            unit_interval = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(unit_interval, eps=self._LOGIT_EPS))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced indices into the full schedule used for fast sampling.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying the sigmoid reparameterization if trainable."""
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        beta (torch.Tensor) - Beta values clamped to `beta_range`, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # Single-step schedule: rescaling would be 0/0, so fall back to
                # beta_min, which is what the linear schedule yields for one step.
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        return torch.clamp(beta, min=beta_min, max=beta_max)

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns the schedule quantities restricted to the subsampled (tau) steps.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values at `tau_indices`, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values at `tau_indices`, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative alpha products at `tau_indices`, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod at `tau_indices`, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) at `tau_indices`, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            # Derived quantities are not cached for trainable schedules.
            full_schedule = self.compute_schedule()
        else:
            full_schedule = (
                self.betas,
                self.alphas,
                self.alpha_cumprod,
                self.sqrt_alpha_cumprod,
                self.sqrt_one_minus_alpha_cumprod,
            )

        betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = full_schedule
        picks = self.tau_indices
        return (betas[picks], alphas[picks], alpha_cumprod[picks],
                sqrt_alpha_cumprod[picks], sqrt_one_minus_alpha_cumprod[picks])

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters from the current betas.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, the returned tensors are indexed with it.
            If None, the tensors for all time steps are returned.

        Returns
        -------
        betas : torch.Tensor
            Beta values.
        alphas : torch.Tensor
            1 - betas.
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas.
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod.
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod).
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)

        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
373
+
374
+
375
+ ###==================================================================================================================###
376
+
377
+
378
+ class TrainDDIM(nn.Module):
379
+ """Trainer for Denoising Diffusion Implicit Models (DDIM).
380
+
381
+ Manages the training process for DDIM, optimizing a noise predictor model to learn
382
+ the noise added by the forward diffusion process. Supports conditional training with
383
+ text prompts, mixed precision training, learning rate scheduling, early stopping, and
384
+ checkpointing, as inspired by Song et al. (2021).
385
+
386
+ Parameters
387
+ ----------
388
+ `noise_predictor` : nn.Module
389
+ Model to predict noise added during the forward diffusion process.
390
+ forward_diffusion : nn.Module
391
+ Forward DDIM diffusion module for adding noise.
392
+ reverse_diffusion: nn.Module
393
+ Reverse DDIM diffusion module for denoising.
394
+ `data_loader` : torch.utils.data.DataLoader
395
+ DataLoader for training data.
396
+ `optimizer` : torch.optim.Optimizer
397
+ Optimizer for training the noise predictor and conditional model (if applicable).
398
+ `objective` : callable
399
+ Loss function to compute the difference between predicted and actual noise.
400
+ `val_loader` : torch.utils.data.DataLoader, optional
401
+ DataLoader for validation data, default None.
402
+ `max_epochs` : int, optional
403
+ Maximum number of training epochs (default: 1000).
404
+ `device` : torch.device, optional
405
+ Device for computation (default: CUDA if available, else CPU).
406
+ `conditional_model` : nn.Module, optional
407
+ Model for conditional generation (e.g., text embeddings), default None.
408
+ `metrics_` : object, optional
409
+ Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
410
+ `bert_tokenizer` : BertTokenizer, optional
411
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
412
+ `max_token_length` : int, optional
413
+ Maximum length for tokenized prompts (default: 77).
414
+ `store_path` : str, optional
415
+ Path to save model checkpoints (default: "ddim_model.pth").
416
+ `patience` : int, optional
417
+ Number of epochs to wait for improvement before early stopping (default: 100).
418
+ `warmup_epochs` : int, optional
419
+ Number of epochs for learning rate warmup (default: 100).
420
+ `val_frequency` : int, optional
421
+ Frequency (in epochs) for validation (default: 10).
422
+ `image_output_range` : tuple, optional
423
+ Range for clamping generated images (default: (-1, 1)).
424
+ `normalize_output` : bool, optional
425
+ Whether to normalize generated images to [0, 1] for metrics (default: True).
426
+ `use_ddp` : bool, optional
427
+ Whether to use Distributed Data Parallel training (default: False).
428
+ `grad_accumulation_steps` : int, optional
429
+ Number of gradient accumulation steps before optimizer update (default: 1).
430
+ `log_frequency` : int, optional
431
+ Number of epochs before printing loss.
432
+ use_compilation : bool, optional
433
+ whether the model is internally compiled using torch.compile (default: false)
434
+ """
435
def __init__(
        self,
        noise_predictor: torch.nn.Module,
        forward_diffusion: torch.nn.Module,
        reverse_diffusion: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        conditional_model: Optional[torch.nn.Module] = None,
        metrics_: Optional[Any] = None,
        bert_tokenizer: Optional[BertTokenizer] = None,
        max_token_length: int = 77,
        store_path: Optional[str] = None,
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        image_output_range: Tuple[float, float] = (-1, 1),
        normalize_output: bool = True,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False
) -> None:
    """Stores the training components and prepares device, schedulers and tokenizer.

    The device is resolved first (CUDA if available when `device` is None),
    then the distributed or single-process bookkeeping is set up, the models
    are moved to the resolved device, and a ReduceLROnPlateau scheduler plus
    a warmup scheduler are attached to `optimizer`.

    Raises
    ------
    ValueError
        If no `bert_tokenizer` is given and the default "bert-base-uncased"
        tokenizer cannot be loaded.
    """
    super().__init__()
    # initialize DDP settings first
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # setup distributed training if enabled (may override self.device)
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to appropriate device
    self.noise_predictor = noise_predictor.to(self.device)
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    # Compare against None explicitly: container modules define __len__, so an
    # empty one would be falsy and silently dropped by a truthiness test.
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ddim_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # initialize tokenizer
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        # Previously a caller-supplied tokenizer was never stored, leaving
        # self.tokenizer undefined for every later use.
        self.tokenizer = bert_tokenizer
514
+
515
+ def _setup_ddp(self) -> None:
516
+ """Setup Distributed Data Parallel training configuration.
517
+
518
+ Initializes process group, determines rank information, and sets up
519
+ CUDA device for the current process.
520
+ """
521
+ # check if DDP environment variables are set
522
+ if "RANK" not in os.environ:
523
+ raise ValueError("DDP enabled but RANK environment variable not set")
524
+ if "LOCAL_RANK" not in os.environ:
525
+ raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
526
+ if "WORLD_SIZE" not in os.environ:
527
+ raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
528
+
529
+ # ensure CUDA is available for DDP
530
+ if not torch.cuda.is_available():
531
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
532
+
533
+ # initialize process group only if not already initialized
534
+ if not torch.distributed.is_initialized():
535
+ init_process_group(backend="nccl")
536
+
537
+ # get rank information
538
+ self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
539
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
540
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
541
+
542
+ # set device and make it current
543
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
544
+ torch.cuda.set_device(self.device)
545
+
546
+ # master process handles logging, checkpointing, etc.
547
+ self.master_process = self.ddp_rank == 0
548
+
549
+ if self.master_process:
550
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
551
+
552
+ def _setup_single_gpu(self) -> None:
553
+ """Setup single GPU or CPU training configuration."""
554
+ self.ddp_rank = 0
555
+ self.ddp_local_rank = 0
556
+ self.ddp_world_size = 1
557
+ self.master_process = True
558
+
559
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a training checkpoint to resume training.

    Restores the noise predictor, the conditional model (when one is
    configured and present in the checkpoint), the variance scheduler and the
    optimizer. State-dict keys are converted between the DDP ('module.'
    prefixed) and plain layouts according to `self.use_ddp`.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch stored in the checkpoint (-1 if absent).
    loss : float
        The loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    KeyError
        If the checkpoint lacks 'model_state_dict_noise_predictor',
        'variance_scheduler_model' or 'optimizer_state_dict'.
    """
    ddp_prefix = 'module.'

    def _align_ddp_prefix(state_dict):
        # Convert keys to the layout expected for the current DDP setting.
        has_prefix = any(key.startswith(ddp_prefix) for key in state_dict)
        if self.use_ddp and not has_prefix:
            # plain checkpoint -> DDP model
            return {f'{ddp_prefix}{k}': v for k, v in state_dict.items()}
        if not self.use_ddp and has_prefix:
            # DDP checkpoint -> plain model. Strip only the LEADING prefix:
            # str.replace would also mangle inner names such as
            # 'submodule.weight' (which contains 'module.').
            return {k.removeprefix(ddp_prefix): v for k, v in state_dict.items()}
        return state_dict

    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

    # noise predictor
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    self.noise_predictor.load_state_dict(
        _align_ddp_prefix(checkpoint['model_state_dict_noise_predictor']))

    # conditional model, if one is configured
    if self.conditional_model is not None:
        cond_state_dict = checkpoint.get('model_state_dict_conditional')
        if cond_state_dict is not None:
            self.conditional_model.load_state_dict(_align_ddp_prefix(cond_state_dict))
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # variance scheduler (best effort: failures only produce a warning)
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    try:
        # NOTE(review): the nesting of these branches was reconstructed from a
        # flattened listing - confirm against the upstream file.
        if isinstance(self.forward_diffusion.variance_scheduler, nn.Module):
            self.forward_diffusion.variance_scheduler.load_state_dict(
                checkpoint['variance_scheduler_model'])
        if isinstance(self.reverse_diffusion.variance_scheduler, nn.Module):
            self.reverse_diffusion.variance_scheduler.load_state_dict(
                checkpoint['variance_scheduler_model'])
        else:
            self.forward_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
            self.reverse_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
    except Exception as e:
        warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # optimizer (a mismatching state is tolerated with a warning)
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
    return epoch, loss
644
+
645
+ @staticmethod
646
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
647
+ """Creates a learning rate scheduler for warmup.
648
+
649
+ Generates a scheduler that linearly increases the learning rate from 0 to the
650
+ optimizer's initial value over the specified warmup epochs, then maintains it.
651
+
652
+ Parameters
653
+ ----------
654
+ `optimizer` : torch.optim.Optimizer
655
+ Optimizer to apply the scheduler to.
656
+ `warmup_epochs` : int
657
+ Number of epochs for the warmup phase.
658
+
659
+ Returns
660
+ -------
661
+ lr_scheduler (torch.optim.lr_scheduler.LambdaLR) - Learning rate scheduler for warmup.
662
+ """
663
+ def lr_lambda(epoch: int) -> float:
664
+ if epoch < warmup_epochs:
665
+ return epoch / warmup_epochs
666
+ return 1.0
667
+
668
+ return LambdaLR(optimizer, lr_lambda)
669
+
670
+ def _wrap_models_for_ddp(self) -> None:
671
+ """Wrap models with DistributedDataParallel for multi-GPU training."""
672
+ if self.use_ddp:
673
+ # wrap noise predictor with DDP
674
+ self.noise_predictor = DDP(
675
+ self.noise_predictor,
676
+ device_ids=[self.ddp_local_rank],
677
+ find_unused_parameters=True
678
+ )
679
+
680
+ # wrap conditional model with DDP if it exists
681
+ if self.conditional_model is not None:
682
+ self.conditional_model = DDP(
683
+ self.conditional_model,
684
+ device_ids=[self.ddp_local_rank],
685
+ find_unused_parameters=True
686
+ )
687
+
688
+ def forward(self) -> Tuple[List, float]:
689
+ """Trains the DDIM model to predict noise added by the forward diffusion process.
690
+
691
+ Executes the training loop, optimizing the noise predictor and conditional model
692
+ (if applicable) using mixed precision, gradient clipping, and learning rate
693
+ scheduling. Supports validation, early stopping, and checkpointing.
694
+
695
+ Returns
696
+ -------
697
+ train_losses : list of float
698
+ List of mean training losses per epoch.
699
+ best_val_loss : float
700
+ Best validation or training loss achieved.
701
+ """
702
+
703
+ # set models to training mode
704
+ self.noise_predictor.train()
705
+ if self.conditional_model is not None:
706
+ self.conditional_model.train()
707
+ if self.forward_diffusion.variance_scheduler.trainable_beta:
708
+ self.reverse_diffusion.train()
709
+ self.forward_diffusion.train()
710
+ else:
711
+ self.reverse_diffusion.eval()
712
+ self.forward_diffusion.eval()
713
+
714
+ # compile models for optimization (if supported)
715
+ if self.use_compilation:
716
+ try:
717
+ self.noise_predictor = torch.compile(self.noise_predictor)
718
+ if self.conditional_model is not None:
719
+ self.conditional_model = torch.compile(self.conditional_model)
720
+ except Exception as e:
721
+ if self.master_process:
722
+ print(f"Model compilation failed: {e}. Continuing without compilation.")
723
+
724
+ # wrap models for DDP after compilation
725
+ self._wrap_models_for_ddp()
726
+
727
+ # initialize training components
728
+ scaler = torch.GradScaler()
729
+ train_losses = []
730
+ best_val_loss = float("inf")
731
+ wait = 0
732
+
733
+ # main training loop
734
+ for epoch in range(self.max_epochs):
735
+ # set epoch for distributed sampler if using DDP
736
+ if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
737
+ self.data_loader.sampler.set_epoch(epoch)
738
+
739
+ train_losses_epoch = []
740
+
741
+ # training step loop with gradient accumulation
742
+ for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
743
+ x = x.to(self.device)
744
+
745
+ # process conditional inputs if conditional model exists
746
+ if self.conditional_model is not None:
747
+ y_encoded = self._process_conditional_input(y)
748
+ else:
749
+ y_encoded = None
750
+
751
+ # forward pass with mixed precision
752
+ with torch.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
753
+ # generate noise and timesteps
754
+ noise = torch.randn_like(x).to(self.device)
755
+ t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)
756
+
757
+ # apply forward diffusion
758
+ noisy_x = self.forward_diffusion(x, noise, t)
759
+
760
+ # predict noise
761
+ predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
762
+
763
+ # compute loss and scale for gradient accumulation
764
+ loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps
765
+
766
+ # backward pass
767
+ scaler.scale(loss).backward()
768
+
769
+ # gradient accumulation and optimizer step
770
+ if (step + 1) % self.grad_accumulation_steps == 0:
771
+ # clip gradients
772
+ scaler.unscale_(self.optimizer)
773
+ torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
774
+ if self.conditional_model is not None:
775
+ torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
776
+
777
+ # optimizer step
778
+ scaler.step(self.optimizer)
779
+ scaler.update()
780
+ self.optimizer.zero_grad()
781
+
782
+ # update learning rate (warmup scheduler)
783
+ self.warmup_lr_scheduler.step()
784
+
785
+ # record loss (unscaled)
786
+ train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)
787
+
788
+ # compute mean training loss
789
+ mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
790
+
791
+ # all-reduce loss across processes for DDP
792
+ if self.use_ddp:
793
+ loss_tensor = torch.tensor(mean_train_loss, device=self.device)
794
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
795
+ mean_train_loss = loss_tensor.item()
796
+
797
+ train_losses.append(mean_train_loss)
798
+
799
+ # print training progress (only master process)
800
+ if self.master_process and (epoch + 1) % self.log_frequency == 0:
801
+ current_lr = self.optimizer.param_groups[0]['lr']
802
+ print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")
803
+
804
+ # validation step
805
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
806
+ val_metrics = self.validate()
807
+ val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics
808
+
809
+ if self.master_process:
810
+ print(f" | Val Loss: {val_loss:.4f}", end="")
811
+ if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
812
+ print(f" | FID: {fid:.4f}", end="")
813
+ if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
814
+ print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
815
+ if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
816
+ print(f" | LPIPS: {lpips_score:.4f}", end="")
817
+ print()
818
+
819
+ current_best = val_loss
820
+ self.scheduler.step(val_loss)
821
+ else:
822
+ if self.master_process:
823
+ print()
824
+ current_best = mean_train_loss
825
+ self.scheduler.step(mean_train_loss)
826
+
827
+ # save checkpoint and early stopping (only master process)
828
+ if self.master_process:
829
+ if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
830
+ best_val_loss = current_best
831
+ wait = 0
832
+ self._save_checkpoint(epoch + 1, best_val_loss)
833
+ else:
834
+ wait += 1
835
+ if wait >= self.patience:
836
+ print("Early stopping triggered")
837
+ self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
838
+ break
839
+
840
+ # clean up DDP
841
+ if self.use_ddp:
842
+ destroy_process_group()
843
+
844
+ return train_losses, best_val_loss
845
+
846
+ def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
847
+ """Process conditional input for text-to-image generation.
848
+
849
+ Parameters
850
+ ----------
851
+ y : torch.Tensor or list
852
+ Conditional input (text prompts).
853
+
854
+ Returns
855
+ -------
856
+ torch.Tensor
857
+ Encoded conditional input.
858
+ """
859
+ # convert to string list
860
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
861
+ y_list = [str(item) for item in y_list]
862
+
863
+ # tokenize
864
+ y_encoded = self.tokenizer(
865
+ y_list,
866
+ padding="max_length",
867
+ truncation=True,
868
+ max_length=self.max_token_length,
869
+ return_tensors="pt"
870
+ ).to(self.device)
871
+
872
+ # get embeddings
873
+ input_ids = y_encoded["input_ids"]
874
+ attention_mask = y_encoded["attention_mask"]
875
+ y_encoded = self.conditional_model(input_ids, attention_mask)
876
+
877
+ return y_encoded
878
+ def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
879
+ """Save model checkpoint (only called by master process).
880
+
881
+ Parameters
882
+ ----------
883
+ epoch : int
884
+ Current epoch number.
885
+ loss : float
886
+ Current loss value.
887
+ suffix : str, optional
888
+ Suffix to add to checkpoint filename.
889
+ """
890
+ try:
891
+ # get state dicts, handling DDP wrapping
892
+ noise_predictor_state = (
893
+ self.noise_predictor.module.state_dict() if self.use_ddp
894
+ else self.noise_predictor.state_dict()
895
+ )
896
+ conditional_state = None
897
+ if self.conditional_model is not None:
898
+ conditional_state = (
899
+ self.conditional_model.module.state_dict() if self.use_ddp
900
+ else self.conditional_model.state_dict()
901
+ )
902
+
903
+ checkpoint = {
904
+ 'epoch': epoch,
905
+ 'model_state_dict_noise_predictor': noise_predictor_state,
906
+ 'model_state_dict_conditional': conditional_state,
907
+ 'optimizer_state_dict': self.optimizer.state_dict(),
908
+ 'loss': loss,
909
+ 'variance_scheduler_model': (
910
+ self.forward_diffusion.variance_scheduler.state_dict() if isinstance(
911
+ self.forward_diffusion.variance_scheduler, nn.Module)
912
+ else self.forward_diffusion.variance_scheduler
913
+ ),
914
+ 'max_epochs': self.max_epochs,
915
+ }
916
+
917
+ filename = f"ddim_epoch_{epoch}{suffix}.pth"
918
+ filepath = os.path.join(self.store_path, filename)
919
+ os.makedirs(self.store_path, exist_ok=True)
920
+ torch.save(checkpoint, filepath)
921
+
922
+ print(f"Model saved at epoch {epoch}")
923
+
924
+ except Exception as e:
925
+ print(f"Failed to save model: {e}")
926
+
927
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Evaluate the noise predictor on the validation loader.

    For every validation batch the noise-prediction loss is computed at random
    timesteps. When `metrics_` and the reverse diffusion module are available,
    a batch of samples is additionally generated by stepping the reverse
    process over the subsampled (tau) schedule, clamped to
    `image_output_range` (and rescaled to [0, 1] when `normalize_output` is
    set), and compared against the input batch with `metrics_`.
    All models are put back into training mode before returning.

    Returns
    -------
    val_loss : float
        Mean validation loss (averaged across processes under DDP).
    fid : float, or `float('inf')` if not computed
        Mean FID score.
    mse : float, or None if not computed
        Mean MSE
    psnr : float, or None if not computed
        Mean PSNR
    ssim : float, or None if not computed
        Mean SSIM
    lpips_score : float, or None if not computed
        Mean LPIPS score
    """
    def _mean_or(values, fallback):
        # average the collected per-batch values, or report the fallback when none exist
        return torch.tensor(values).mean().item() if values else fallback

    scheduler = self.forward_diffusion.variance_scheduler

    # evaluation mode for everything that takes part in the forward passes
    self.noise_predictor.eval()
    if self.conditional_model is not None:
        self.conditional_model.eval()
    if scheduler.trainable_beta:
        self.forward_diffusion.eval()
        self.reverse_diffusion.eval()

    batch_losses = []
    collected = {"fid": [], "mse": [], "psnr": [], "ssim": [], "lpips": []}

    with torch.no_grad():
        for x, y in self.val_loader:
            x = x.to(self.device)
            x_orig = x.clone()

            # encode the conditioning input when a conditional model exists
            y_encoded = self._process_conditional_input(y) if self.conditional_model is not None else None

            # noise-prediction loss at random timesteps
            noise = torch.randn_like(x).to(self.device)
            t = torch.randint(0, scheduler.num_steps, (x.shape[0],)).to(self.device)
            noisy_x = self.forward_diffusion(x, noise, t)
            predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
            batch_losses.append(self.objective(predicted_noise, noise).item())

            # image metrics need both a metrics object and a reverse process
            if self.metrics_ is None or self.reverse_diffusion is None:
                continue

            # reverse diffusion sampling starting from pure noise
            xt = torch.randn_like(x).to(self.device)
            for tau in reversed(range(scheduler.tau_num_steps)):
                time_steps = torch.full((xt.shape[0],), tau, device=self.device)
                prev_time_steps = torch.full((xt.shape[0],), max(tau - 1, 0), device=self.device)
                predicted_noise = self.noise_predictor(xt, time_steps, y_encoded, None)
                xt, _ = self.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)

            # clamp and optionally rescale generated and reference images
            range_min = self.image_output_range[0]
            range_max = self.image_output_range[1]
            x_hat = torch.clamp(xt, min=range_min, max=range_max)
            if self.normalize_output:
                x_hat = (x_hat - range_min) / (range_max - range_min)
                x_orig = (x_orig - range_min) / (range_max - range_min)

            # compute metrics and keep only the ones that are switched on
            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x_orig, x_hat)
            if getattr(self.metrics_, 'fid', False):
                collected["fid"].append(fid)
            if getattr(self.metrics_, 'metrics', False):
                collected["mse"].append(mse)
                collected["psnr"].append(psnr)
                collected["ssim"].append(ssim)
            if getattr(self.metrics_, 'lpips', False):
                collected["lpips"].append(lpips_score)

    val_loss = torch.tensor(batch_losses).mean().item()

    # average the validation loss across processes for DDP
    if self.use_ddp:
        val_loss_tensor = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
        val_loss = val_loss_tensor.item()

    fid_avg = _mean_or(collected["fid"], float('inf'))
    mse_avg = _mean_or(collected["mse"], None)
    psnr_avg = _mean_or(collected["psnr"], None)
    ssim_avg = _mean_or(collected["ssim"], None)
    lpips_avg = _mean_or(collected["lpips"], None)

    # return to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
1035
+
1036
+ ###==================================================================================================================###
1037
+
1038
+ class SampleDDIM(nn.Module):
1039
+ """Image generation using a trained DDIM model.
1040
+
1041
+ Implements the sampling process for DDIM, generating images by iteratively denoising
1042
+ random noise using a trained noise predictor and reverse diffusion process with a
1043
+ subsampled time step schedule. Supports conditional generation with text prompts,
1044
+ as inspired by Song et al. (2021).
1045
+
1046
+ Parameters
1047
+ ----------
1048
+ `reverse_diffusion` : nn.Module
1049
+ Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
1050
+ `noise_predictor` : nn.Module
1051
+ Trained model to predict noise at each time step.
1052
+ `image_shape` : tuple
1053
+ Tuple of (height, width) specifying the generated image dimensions.
1054
+ `conditional_model` : nn.Module, optional
1055
+ Model for conditional generation (e.g., text embeddings), default None.
1056
+ `tokenizer` : str, optional
1057
+ Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
1058
+ `max_length` : int, optional
1059
+ Maximum length for tokenized prompts (default: 77).
1060
+ `batch_size` : int, optional
1061
+ Number of images to generate per batch (default: 1).
1062
+ `in_channels` : int, optional
1063
+ Number of input channels for generated images (default: 3).
1064
+ `device` : torch.device, optional
1065
+ Device for computation (default: CUDA if available, else CPU).
1066
+ `output_range` : tuple, optional
1067
+ Tuple of (min, max) for clamping generated images (default: (-1, 1)).
1068
+ """
1069
+ def __init__(
1070
+ self,
1071
+ reverse_diffusion: torch.nn.Module,
1072
+ noise_predictor: torch.nn.Module,
1073
+ image_shape: Tuple[int, int],
1074
+ conditional_model: Optional[torch.nn.Module] = None,
1075
+ bert_tokenizer: str = "bert-base-uncased",
1076
+ max_token_length: int = 77,
1077
+ batch_size: int = 1,
1078
+ in_channels: int = 3,
1079
+ device: Optional[str] = None,
1080
+ image_output_range: Tuple[float, float] = (-1.0, 1.0)
1081
+ ) -> None:
1082
+ super().__init__()
1083
+ if device is None:
1084
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1085
+ elif isinstance(device, str):
1086
+ self.device = torch.device(device)
1087
+ else:
1088
+ self.device = device
1089
+ self.reverse = reverse_diffusion.to(self.device)
1090
+ self.noise_predictor = noise_predictor.to(self.device)
1091
+ self.conditional_model = conditional_model.to(self.device) if conditional_model else None
1092
+ self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer)
1093
+ self.max_token_length = max_token_length
1094
+ self.in_channels = in_channels
1095
+ self.image_shape = image_shape
1096
+ self.batch_size = batch_size
1097
+ self.image_output_range = image_output_range
1098
+
1099
+ if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
1100
+ isinstance(s, int) and s > 0 for s in image_shape):
1101
+ raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
1102
+ if batch_size <= 0:
1103
+ raise ValueError("batch_size must be positive")
1104
+ if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
1105
+ raise ValueError("image_output_range must be a tuple (min, max) with min < max")
1106
+
1107
+
1108
+ def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
1109
+ """Tokenizes text prompts for conditional generation.
1110
+
1111
+ Converts input prompts into tokenized input IDs and attention masks using the
1112
+ specified tokenizer, suitable for use with the conditional model.
1113
+
1114
+ Parameters
1115
+ ----------
1116
+ `prompts` : str or list
1117
+ A single text prompt or a list of text prompts.
1118
+
1119
+ Returns
1120
+ -------
1121
+ input_ids : torch.Tensor
1122
+ Tokenized input IDs, shape (batch_size, max_length).
1123
+ attention_mask : torch.Tensor
1124
+ Attention mask, shape (batch_size, max_length).
1125
+ """
1126
+ if isinstance(prompts, str):
1127
+ prompts = [prompts]
1128
+ elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
1129
+ raise TypeError("prompts must be a string or list of strings")
1130
+ encoded = self.tokenizer(
1131
+ prompts,
1132
+ padding="max_length",
1133
+ truncation=True,
1134
+ max_length=self.max_token_length,
1135
+ return_tensors="pt"
1136
+ )
1137
+ return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
1138
+
1139
+ def forward(self, conditions: Optional[Union[str, List]] = None, normalize_output: bool = True, save_images: bool = True, save_path: str = "ddim_generated") -> torch.Tensor:
1140
+ """Generates images using the DDIM sampling process.
1141
+
1142
+ Iteratively denoises random noise to generate images using the reverse diffusion
1143
+ process with a subsampled time step schedule and noise predictor. Supports
1144
+ conditional generation with text prompts.
1145
+
1146
+ Parameters
1147
+ ----------
1148
+ `conditions` : str or list, optional
1149
+ Text prompt(s) for conditional generation, default None.
1150
+ `normalize_output` : bool, optional
1151
+ If True, normalizes output images to [0, 1] (default: True).
1152
+ `save_images` : bool, optional
1153
+ If True, saves generated images to `save_path` (default: True).
1154
+ `save_path` : str, optional
1155
+ Directory to save generated images (default: "ddim_generated").
1156
+
1157
+ Returns
1158
+ -------
1159
+ generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
1160
+ """
1161
+
1162
+ if conditions is not None and self.conditional_model is None:
1163
+ raise ValueError("Conditions provided but no conditional model specified")
1164
+ if conditions is None and self.conditional_model is not None:
1165
+ raise ValueError("Conditions must be provided for conditional model")
1166
+
1167
+ noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)
1168
+
1169
+ self.noise_predictor.eval()
1170
+ self.reverse.eval()
1171
+ if self.conditional_model:
1172
+ self.conditional_model.eval()
1173
+
1174
+ with torch.no_grad():
1175
+ xt = noisy_samples
1176
+ for t in reversed(range(self.reverse.variance_scheduler.tau_num_steps)):
1177
+ time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
1178
+ prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long)
1179
+
1180
+ if self.conditional_model is not None and conditions is not None:
1181
+ input_ids, attention_masks = self.tokenize(conditions)
1182
+ key_padding_mask = (attention_masks == 0)
1183
+ y = self.conditional_model(input_ids, key_padding_mask)
1184
+ predicted_noise = self.noise_predictor(xt, time_steps, y, None)
1185
+ else:
1186
+ predicted_noise = self.noise_predictor(xt, time_steps, None)
1187
+
1188
+ xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
1189
+
1190
+ generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
1191
+ if normalize_output:
1192
+ generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
1193
+
1194
+ if save_images:
1195
+ os.makedirs(save_path, exist_ok=True) # create directory if it doesn't exist
1196
+ for i in range(generated_imgs.size(0)):
1197
+ img_path = os.path.join(save_path, f"image_{i+1}.png")
1198
+ save_image(generated_imgs[i], img_path)
1199
+
1200
+ return generated_imgs
1201
+
1202
+ def to(self, device: torch.device) -> Self:
1203
+ """Moves the module and its components to the specified device.
1204
+
1205
+ Updates the device attribute and moves the reverse diffusion, noise predictor,
1206
+ and conditional model (if present) to the specified device.
1207
+
1208
+ Parameters
1209
+ ----------
1210
+ `device` : torch.device
1211
+ Target device for the module and its components.
1212
+
1213
+ Returns
1214
+ -------
1215
+ sample_ddim (SampleDDIM) - moved to the specified device.
1216
+ """
1217
+ self.device = device
1218
+ self.noise_predictor.to(device)
1219
+ self.reverse.to(device)
1220
+ if self.conditional_model:
1221
+ self.conditional_model.to(device)
1222
+ return super().to(device)