TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/ddpm.py ADDED
@@ -0,0 +1,1153 @@
1
+ """
2
+ **Denoising Diffusion Probabilistic Models (DDPM) implementation**
3
+
4
+ This module provides a complete implementation of DDPM, as described in Ho et al.
5
+ (2020, "Denoising Diffusion Probabilistic Models"). It includes components for forward
6
+ and reverse diffusion processes, hyperparameter management, training, and image
7
+ sampling. Supports both unconditional and conditional generation with text prompts.
8
+
9
+ **Components**
10
+
11
+ - **ForwardDDPM**: Forward diffusion process to add noise.
12
+ - **ReverseDDPM**: Reverse diffusion process to denoise.
13
+ - **VarianceSchedulerDDPM**: Noise schedule management.
14
+ - **TrainDDPM**: Training loop with mixed precision and scheduling.
15
+ - **SampleDDPM**: Image generation from trained models.
16
+
17
+ **References**
18
+
19
+ - Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models.
20
+
21
+ - Salimans, Tim, et al. "Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications."
22
+ arXiv preprint arXiv:1701.05517 (2017).
23
+
24
+ -------------------------------------------------------------------------------
25
+ """
26
+
27
+
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from typing import Optional, Tuple, Callable, List, Any, Union, Self
32
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
33
+ import torch.distributed as dist
34
+ from torch.nn.parallel import DistributedDataParallel as DDP
35
+ from torch.distributed import init_process_group, destroy_process_group
36
+ from tqdm import tqdm
37
+ from transformers import BertTokenizer
38
+ import warnings
39
+ from torchvision.utils import save_image
40
+ import os
41
+
42
+
43
+ ###==================================================================================================================###
44
+
45
+
46
class ForwardDDPM(nn.Module):
    """Forward diffusion process for Denoising Diffusion Probabilistic Models (DDPM).

    Implements the forward (noising) process of Ho et al. (2020), which perturbs clean
    data by mixing it with Gaussian noise according to a variance schedule. The schedule
    is either fixed (precomputed tensors read from the scheduler) or trainable
    (recomputed on every call through ``variance_scheduler.compute_schedule``).

    Parameters
    ----------
    variance_scheduler : torch.nn.Module
        Noise-schedule object (e.g. ``VarianceSchedulerDDPM``). It must expose
        ``num_steps`` and ``trainable_beta``. When ``trainable_beta`` is False it must
        also expose the ``sqrt_alpha_bars`` and ``sqrt_one_minus_alpha_bars`` tensors;
        otherwise it must provide ``compute_schedule(time_steps)`` returning
        ``(betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars)``.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Noise clean data to the requested diffusion steps.

        Uses the reparameterization x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε.

        Parameters
        ----------
        x0 : torch.Tensor
            Clean data. The first dimension is the batch; any number of further
            dimensions is supported (e.g. ``(batch, channels, height, width)``).
        noise : torch.Tensor
            Gaussian noise with the same shape as ``x0``.
        time_steps : torch.Tensor
            Integer step indices of shape ``(batch,)``, each in ``[0, num_steps - 1]``.

        Returns
        -------
        torch.Tensor
            Noisy data ``x_t`` with the same shape as ``x0``.

        Raises
        ------
        ValueError
            If any entry of ``time_steps`` is outside ``[0, num_steps - 1]``.
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.variance_scheduler.trainable_beta:
            # Trainable schedules are recomputed so gradients flow into the betas.
            _, _, _, sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t = self.variance_scheduler.compute_schedule(time_steps)
        else:
            sqrt_alpha_bar_t = self.variance_scheduler.sqrt_alpha_bars[time_steps]
            sqrt_one_minus_alpha_bar_t = self.variance_scheduler.sqrt_one_minus_alpha_bars[time_steps]

        # Reshape the per-sample (batch,) coefficients to (batch, 1, ..., 1) so they
        # broadcast over every non-batch dimension of x0, whatever its rank. A fixed
        # view(-1, 1, 1, 1) would silently mis-broadcast non-4D inputs.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = sqrt_alpha_bar_t.to(x0.device).view(coeff_shape)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.to(x0.device).view(coeff_shape)

        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
102
+
103
+ ###==================================================================================================================###
104
+
105
+
106
class ReverseDDPM(nn.Module):
    """Reverse (denoising) process for Denoising Diffusion Probabilistic Models (DDPM).

    Implements one ancestral sampling step of Ho et al. (2020). Given a noisy sample
    ``x_t`` and the noise ``ε̂`` predicted for it, it draws ``x_{t-1}`` from
    ``N(μ_t, σ_t² I)`` with

        μ_t  = (x_t - β_t / sqrt(1 - ᾱ_t) * ε̂) / sqrt(α_t)
        σ_t² = (1 - ᾱ_{t-1}) / (1 - ᾱ_t) * β_t

    Samples whose time step is 0 receive ``μ_t`` exactly; no noise is added to them.

    Parameters
    ----------
    variance_scheduler : torch.nn.Module
        Noise-schedule object (e.g. ``VarianceSchedulerDDPM``) exposing ``num_steps`` and
        ``trainable_beta``. Fixed schedules must provide ``betas``, ``alphas`` and
        ``alpha_bars`` tensors of length ``num_steps``. Trainable schedules must provide
        ``compute_schedule()`` returning the full-length
        ``(betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars)``.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, xt: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Apply one reverse diffusion step.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy samples at step ``t``. The first dimension is the batch; any number of
            further dimensions is supported (e.g. ``(batch, channels, height, width)``).
        predicted_noise : torch.Tensor
            Noise predicted for ``xt`` (typically by a neural network); same shape as ``xt``.
        time_steps : torch.Tensor
            Integer step indices of shape ``(batch,)``, each in ``[0, num_steps - 1]``.

        Returns
        -------
        torch.Tensor
            Samples at step ``t - 1``, with the same shape as ``xt``.

        Raises
        ------
        ValueError
            If any entry of ``time_steps`` is outside ``[0, num_steps - 1]``.
        """
        scheduler = self.variance_scheduler
        num_steps = scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if scheduler.trainable_beta:
            # Evaluate the learnable schedule once; both ᾱ_t and ᾱ_{t-1} are read from it.
            betas, alphas, alpha_bars, _, _ = scheduler.compute_schedule()
        else:
            betas, alphas, alpha_bars = scheduler.betas, scheduler.alphas, scheduler.alpha_bars

        betas_t = betas[time_steps].to(xt.device)
        alphas_t = alphas[time_steps].to(xt.device)
        alpha_bars_t = alpha_bars[time_steps].to(xt.device)
        # For t == 0 this reads ᾱ_0 instead of a non-existent ᾱ_{-1}. Those samples get no
        # noise (masked below), so only the t > 0 entries, which are genuine ᾱ_{t-1},
        # affect the result.
        alpha_bars_prev = alpha_bars[(time_steps - 1).clamp(min=0)].to(xt.device)

        # (batch,) -> (batch, 1, ..., 1) so the per-sample coefficients broadcast over xt.
        coeff_shape = (-1,) + (1,) * (xt.dim() - 1)
        noise_coeff = (betas_t / torch.sqrt(1 - alpha_bars_t)).view(coeff_shape)
        mu = (xt - noise_coeff * predicted_noise) / torch.sqrt(alphas_t).view(coeff_shape)

        is_final = (time_steps == 0).to(xt.device)
        if is_final.all():
            return mu

        variance = (1 - alpha_bars_prev) / (1 - alpha_bars_t) * betas_t
        std = torch.sqrt(variance).view(coeff_shape)
        # 0 for samples at t == 0 (deterministic final step), 1 otherwise.
        keep_noise = (~is_final).to(mu.dtype).view(coeff_shape)
        return mu + keep_noise * std * torch.randn_like(xt)
184
+
185
+
186
+ ###==================================================================================================================###
187
+
188
+
189
class VarianceSchedulerDDPM(nn.Module):
    """Noise (variance) schedule for Denoising Diffusion Probabilistic Models (DDPM).

    Builds the per-step ``betas`` and the quantities derived from them that the forward
    and reverse processes use: ``alphas = 1 - betas``, ``alpha_bars = cumprod(alphas)``
    and their square roots (Ho et al., 2020). The schedule is either fixed (stored as
    buffers) or trainable.

    For a trainable schedule the ``betas`` parameter holds unconstrained logits. The
    actual betas are ``beta_start + (beta_end - beta_start) * sigmoid(betas)`` (see
    ``compute_schedule``), which keeps every beta inside ``[beta_start, beta_end]``.
    The logits are initialised so that the initial decoded schedule matches the one
    requested through ``beta_method``.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000). Must be positive.
    beta_start : float, optional
        Smallest beta (default: 1e-4).
    beta_end : float, optional
        Largest beta (default: 0.02). ``0 < beta_start < beta_end < 1`` is required.
    trainable_beta : bool, optional
        Whether the beta schedule is learnable (default: False).
    beta_method : str, optional
        Shape of the (initial) beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Raises
    ------
    ValueError
        If the beta range or ``num_steps`` is invalid, or ``beta_method`` is unknown.
    """

    # Margin that keeps the normalized initial betas strictly inside (0, 1) before the
    # logit is taken, so schedule endpoints map to finite parameter values.
    _LOGIT_EPS = 1e-4

    def __init__(self, num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02, trainable_beta: bool = False, beta_method: str = "linear") -> None:
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Invert the decoding used in compute_schedule():
            #   beta = beta_start + (beta_end - beta_start) * sigmoid(param)
            # so that the initial trainable schedule equals betas_init. Storing
            # log(betas_init) instead would decode to values close to beta_start.
            unit = (betas_init - beta_start) / (beta_end - beta_start)
            unit = unit.clamp(self._LOGIT_EPS, 1 - self._LOGIT_EPS)
            self.betas = nn.Parameter(torch.logit(unit))
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_bars', torch.sqrt(self.alpha_bars))
            self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1 - self.alpha_bars))

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Compute the beta schedule for the given method.

        Parameters
        ----------
        beta_range : tuple of float
            ``(min_beta, max_beta)``; the result is clamped to this range.
        num_steps : int
            Number of diffusion steps.
        method : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Beta values of shape ``(num_steps,)``.

        Raises
        ------
        ValueError
            If ``method`` is not supported.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min**0.5, beta_max**0.5, num_steps)
            beta = x**2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            inv = 1.0 / torch.linspace(num_steps, 1, num_steps)
            spread = inv.max() - inv.min()
            if spread > 0:
                beta = beta_min + (beta_max - beta_min) * (inv - inv.min()) / spread
            else:
                # num_steps == 1: there is no spread to normalize (0/0 would give NaN);
                # fall back to beta_min, which is also what "linear" yields for one step.
                beta = torch.full((num_steps,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the schedule quantities from the current betas.

        For trainable schedules the betas are decoded from the logits on every call, so
        the returned tensors carry gradients back to the ``betas`` parameter.
        ``alpha_bars`` is always accumulated over the full schedule before indexing.

        Parameters
        ----------
        time_steps : torch.Tensor, optional
            Integer step indices of shape ``(batch_size,)``. If None, values for all
            ``num_steps`` steps are returned.

        Returns
        -------
        tuple of torch.Tensor
            ``(betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars)``,
            each of shape ``(batch_size,)`` if ``time_steps`` is given, else ``(num_steps,)``.
        """
        if self.trainable_beta:
            # Decode the unconstrained logits into [beta_start, beta_end].
            betas = torch.sigmoid(self.betas) * (self.beta_end - self.beta_start) + self.beta_start
        else:
            betas = self.betas

        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)

        if time_steps is not None:
            betas = betas[time_steps]
            alphas = alphas[time_steps]
            alpha_bars = alpha_bars[time_steps]

        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)
304
+
305
+
306
+ ###==================================================================================================================###
307
+
308
+
309
+ class TrainDDPM(nn.Module):
310
+ """Trainer for Denoising Diffusion Probabilistic Models (DDPM) with Multi-GPU Support.
311
+
312
+ Manages the training process for DDPM, optimizing a noise predictor model to learn
313
+ the noise added by the forward diffusion process. Supports conditional training with
314
+ text prompts, mixed precision training, learning rate scheduling, early stopping,
315
+ checkpointing, and distributed data parallel (DDP) training across multiple GPUs.
316
+
317
+ Parameters
318
+ ----------
319
+ noise_predictor : nn.Module
320
+ Model to predict noise added during the forward diffusion process.
321
+ forward_diffusion : nn.Module
322
+ Forward DDPM diffusion module for adding noise.
323
+ reverse_diffusion: nn.Module
324
+ Reverse DDPM diffusion module for denoising.
325
+ data_loader : torch.utils.data.DataLoader
326
+ DataLoader for training data. Should be wrapped with DistributedSampler for DDP.
327
+ optimizer : torch.optim.Optimizer
328
+ Optimizer for training the noise predictor and conditional model (if applicable).
329
+ objective : callable
330
+ Loss function to compute the difference between predicted and actual noise.
331
+ val_loader : torch.utils.data.DataLoader, optional
332
+ DataLoader for validation data, default None.
333
+ max_epochs : int, optional
334
+ Maximum number of training epochs (default: 1000).
335
+ device : torch.device, optional
336
+ Device for computation (default: CUDA if available, else CPU).
337
+ conditional_model : nn.Module, optional
338
+ Model for conditional generation (e.g., text embeddings), default None.
339
+ metrics_ : object, optional
340
+ Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
341
+ tokenizer : BertTokenizer, optional
342
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
343
+ max_token_length : int, optional
344
+ Maximum length for tokenized prompts (default: 77).
345
+ store_path : str, optional
346
+ Path to save model checkpoints (default: "ddpm_model").
347
+ patience : int, optional
348
+ Number of epochs to wait for improvement before early stopping (default: 100).
349
+ warmup_epochs : int, optional
350
+ Number of epochs for learning rate warmup (default: 100).
351
+ val_frequency : int, optional
352
+ Frequency (in epochs) for validation (default: 10).
353
+ image_output_range : tuple, optional
354
+ Range for clamping generated images (default: (-1, 1)).
355
+ normalize_output : bool, optional
356
+ Whether to normalize generated images to [0, 1] for metrics (default: True).
357
+ use_ddp : bool, optional
358
+ Whether to use Distributed Data Parallel training (default: False).
359
+ grad_accumulation_steps : int, optional
360
+ Number of gradient accumulation steps before optimizer update (default: 1).
361
+ log_frequency : int, optional
362
+ Number of epochs before printing loss.
363
+ use_compilation : bool, optional
364
+ whether the model is internally compiled using torch.compile (default: false)
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ noise_predictor: torch.nn.Module,
370
+ forward_diffusion: torch.nn.Module,
371
+ reverse_diffusion: torch.nn.Module,
372
+ data_loader: torch.utils.data.DataLoader,
373
+ optimizer: torch.optim.Optimizer,
374
+ objective: Callable,
375
+ val_loader: Optional[torch.utils.data.DataLoader] = None,
376
+ max_epochs: int = 1000,
377
+ device: Optional[Union[str, torch.device]] = None,
378
+ conditional_model: torch.nn.Module = None,
379
+ metrics_: Optional[Any] = None,
380
+ bert_tokenizer: Optional[BertTokenizer] = None,
381
+ max_token_length: int = 77,
382
+ store_path: Optional[str] = None,
383
+ patience: int = 100,
384
+ warmup_epochs: int = 100,
385
+ val_frequency: int = 10,
386
+ image_output_range: Tuple[float, float] = (-1.0, 1.0),
387
+ normalize_output: bool = True,
388
+ use_ddp: bool = False,
389
+ grad_accumulation_steps: int = 1,
390
+ log_frequency: int = 1,
391
+ use_compilation: bool = False
392
+ ) -> None:
393
+ super().__init__()
394
+
395
+ # initialize DDP settings first
396
+ self.use_ddp = use_ddp
397
+ self.grad_accumulation_steps = grad_accumulation_steps
398
+ if device is None:
399
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
400
+ elif isinstance(device, str):
401
+ self.device = torch.device(device)
402
+ else:
403
+ self.device = device
404
+
405
+ # setup distributed training if enabled
406
+ if self.use_ddp:
407
+ self._setup_ddp()
408
+ else:
409
+ self._setup_single_gpu()
410
+
411
+ # move models to appropriate device
412
+ self.noise_predictor = noise_predictor.to(self.device)
413
+ self.forward_diffusion = forward_diffusion.to(self.device)
414
+ self.reverse_diffusion = reverse_diffusion.to(self.device)
415
+ self.conditional_model = conditional_model.to(self.device) if conditional_model else None
416
+
417
+ # training components
418
+ self.metrics_ = metrics_
419
+ self.optimizer = optimizer
420
+ self.objective = objective
421
+ self.store_path = store_path or "ddpm_model"
422
+ self.data_loader = data_loader
423
+ self.val_loader = val_loader
424
+ self.max_epochs = max_epochs
425
+ self.max_token_length = max_token_length
426
+ self.patience = patience
427
+ self.val_frequency = val_frequency
428
+ self.image_output_range = image_output_range
429
+ self.normalize_output = normalize_output
430
+ self.log_frequency = log_frequency
431
+ self.use_compilation = use_compilation
432
+
433
+ # learning rate scheduling
434
+ self.scheduler = ReduceLROnPlateau(
435
+ self.optimizer,
436
+ patience=self.patience,
437
+ factor=0.5
438
+ )
439
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
440
+
441
+ # initialize tokenizer
442
+ if bert_tokenizer is None:
443
+ try:
444
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
445
+ except Exception as e:
446
+ raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
447
+ else:
448
+ self.tokenizer = bert_tokenizer
449
+
450
+ def _setup_ddp(self) -> None:
451
+ """Setup Distributed Data Parallel training configuration.
452
+
453
+ Initializes process group, determines rank information, and sets up
454
+ CUDA device for the current process.
455
+ """
456
+ # check if DDP environment variables are set
457
+ if "RANK" not in os.environ:
458
+ raise ValueError("DDP enabled but RANK environment variable not set")
459
+ if "LOCAL_RANK" not in os.environ:
460
+ raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
461
+ if "WORLD_SIZE" not in os.environ:
462
+ raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
463
+
464
+ # ensure CUDA is available for DDP
465
+ if not torch.cuda.is_available():
466
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
467
+
468
+ # initialize process group only if not already initialized
469
+ if not torch.distributed.is_initialized():
470
+ init_process_group(backend="nccl")
471
+
472
+ # get rank information
473
+ self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
474
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
475
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
476
+
477
+ # set device and make it current
478
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
479
+ torch.cuda.set_device(self.device)
480
+
481
+ # master process handles logging, checkpointing, etc.
482
+ self.master_process = self.ddp_rank == 0
483
+
484
+ if self.master_process:
485
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
486
+
487
+ def _setup_single_gpu(self) -> None:
488
+ """Setup single GPU or CPU training configuration."""
489
+ self.ddp_rank = 0
490
+ self.ddp_local_rank = 0
491
+ self.ddp_world_size = 1
492
+ self.master_process = True
493
+
494
+ def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
495
+ """Loads a training checkpoint to resume training.
496
+
497
+ Restores the state of the noise predictor, conditional model (if applicable),
498
+ and optimizer from a saved checkpoint. Handles DDP model state dict loading.
499
+
500
+ Parameters
501
+ ----------
502
+ checkpoint_path : str
503
+ Path to the checkpoint file.
504
+
505
+ Returns
506
+ -------
507
+ epoch : int
508
+ The epoch at which the checkpoint was saved.
509
+ loss : float
510
+ The loss at the checkpoint.
511
+ """
512
+ try:
513
+ # load checkpoint with proper device mapping
514
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
515
+ except FileNotFoundError:
516
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
517
+
518
+ # load noise predictor state
519
+ if 'model_state_dict_noise_predictor' not in checkpoint:
520
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
521
+
522
+ # handle DDP wrapped model state dict
523
+ state_dict = checkpoint['model_state_dict_noise_predictor']
524
+ if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
525
+ # if loading non-DDP checkpoint into DDP model, add 'module.' prefix
526
+ state_dict = {f'module.{k}': v for k, v in state_dict.items()}
527
+ elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
528
+ # if loading DDP checkpoint into non-DDP model, remove 'module.' prefix
529
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
530
+
531
+ self.noise_predictor.load_state_dict(state_dict)
532
+
533
+ # load conditional model state if applicable
534
+ if self.conditional_model is not None:
535
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
536
+ cond_state_dict = checkpoint['model_state_dict_conditional']
537
+ # handle DDP wrapping for conditional model
538
+ if self.use_ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()):
539
+ cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()}
540
+ elif not self.use_ddp and any(key.startswith('module.') for key in cond_state_dict.keys()):
541
+ cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()}
542
+ self.conditional_model.load_state_dict(cond_state_dict)
543
+ else:
544
+ warnings.warn(
545
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
546
+ "skipping conditional model loading"
547
+ )
548
+
549
+ # load variance_scheduler state
550
+ if 'variance_scheduler_model' not in checkpoint:
551
+ raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
552
+ try:
553
+ if isinstance(self.forward_diffusion.variance_scheduler, nn.Module):
554
+ self.forward_diffusion.variance_scheduler.load_state_dict(
555
+ checkpoint['variance_scheduler_model'])
556
+ if isinstance(self.reverse_diffusion.variance_scheduler, nn.Module):
557
+ self.reverse_diffusion.variance_scheduler.load_state_dict(
558
+ checkpoint['variance_scheduler_model'])
559
+ else:
560
+ self.forward_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
561
+ self.reverse_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
562
+ except Exception as e:
563
+ warnings.warn(
564
+ f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")
565
+
566
+ # load optimizer state
567
+ if 'optimizer_state_dict' not in checkpoint:
568
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
569
+ try:
570
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
571
+ except ValueError as e:
572
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
573
+
574
+ epoch = checkpoint.get('epoch', -1)
575
+ loss = checkpoint.get('loss', float('inf'))
576
+
577
+ if self.master_process:
578
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
579
+ return epoch, loss
580
+
581
+ @staticmethod
582
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
583
+ """Creates a learning rate scheduler for warmup.
584
+
585
+ Generates a scheduler that linearly increases the learning rate from 0 to the
586
+ optimizer's initial value over the specified warmup epochs, then maintains it.
587
+
588
+ Parameters
589
+ ----------
590
+ optimizer : torch.optim.Optimizer
591
+ Optimizer to apply the scheduler to.
592
+ warmup_epochs : int
593
+ Number of epochs for the warmup phase.
594
+
595
+ Returns
596
+ -------
597
+ torch.optim.lr_scheduler.LambdaLR
598
+ Learning rate scheduler for warmup.
599
+ """
600
+
601
+ def lr_lambda(epoch):
602
+ if epoch < warmup_epochs:
603
+ return epoch / warmup_epochs
604
+ return 1.0
605
+
606
+ return LambdaLR(optimizer, lr_lambda)
607
+
608
+ def _wrap_models_for_ddp(self) -> None:
609
+ """Wrap models with DistributedDataParallel for multi-GPU training."""
610
+ if self.use_ddp:
611
+ # wrap noise predictor with DDP
612
+ self.noise_predictor = DDP(
613
+ self.noise_predictor,
614
+ device_ids=[self.ddp_local_rank],
615
+ find_unused_parameters=True
616
+ )
617
+
618
+ # wrap conditional model with DDP if it exists
619
+ if self.conditional_model is not None:
620
+ self.conditional_model = DDP(
621
+ self.conditional_model,
622
+ device_ids=[self.ddp_local_rank],
623
+ find_unused_parameters=True
624
+ )
625
+
626
def forward(self) -> Tuple[List, float]:
    """Train the DDPM noise predictor to predict the noise added by forward diffusion.

    Runs the training loop with optional ``torch.compile`` compilation, DDP wrapping,
    gradient accumulation, mixed precision (``torch.autocast`` + ``GradScaler``),
    gradient clipping (max norm 1.0), a per-optimizer-step warmup scheduler and a
    loss-driven scheduler (``self.scheduler.step(loss)``), periodic validation,
    checkpointing and early stopping.

    Early stopping and checkpointing are evaluated only on evaluation epochs (every
    ``val_frequency`` epochs), so ``patience`` counts evaluations without improvement.
    Under DDP the stop decision is shared with every rank so all processes leave the
    loop together instead of the non-master ranks blocking in a collective.

    Returns
    -------
    train_losses : list of float
        Mean training loss per epoch (averaged across ranks under DDP).
    best_val_loss : float
        Best validation loss (or training loss when no validation loader is set)
        seen on an evaluation epoch. Tracked on the master process only; other ranks
        return ``inf``.
    """
    # set models to training mode; diffusion modules only train when betas are learnable
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (best effort: fall back to eager on failure)
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # autocast needs a device *type*. Comparing self.device to the string 'cuda'
    # fails for indexed devices such as 'cuda:0' (the usual DDP setup) and for
    # torch.device objects, which silently ran GPU training without CUDA autocast.
    autocast_device = "cuda" if torch.device(self.device).type == "cuda" else "cpu"

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0
    num_steps = self.forward_diffusion.variance_scheduler.num_steps

    for epoch in range(self.max_epochs):
        # reshuffle distributed shards differently every epoch
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        has_pending_grads = False

        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)
            y_encoded = self._process_conditional_input(y) if self.conditional_model is not None else None

            with torch.autocast(device_type=autocast_device):
                noise = torch.randn_like(x)  # x is already on self.device
                t = torch.randint(0, num_steps, (x.shape[0],)).to(self.device)
                noisy_x = self.forward_diffusion(x, noise, t)
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
                # divide so the accumulated gradient averages over micro-batches
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            scaler.scale(loss).backward()
            has_pending_grads = True

            if (step + 1) % self.grad_accumulation_steps == 0:
                self._apply_accumulated_gradients(scaler)
                has_pending_grads = False

            # record the unscaled per-batch loss
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # apply a trailing partial accumulation window instead of letting its gradients
        # leak into the next epoch's first step (or be discarded after the last epoch)
        if has_pending_grads:
            self._apply_accumulated_gradients(scaler)

        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

        # all-reduce loss across processes for DDP
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        is_eval_epoch = (epoch + 1) % self.val_frequency == 0

        # validation step (validate() runs collectives, so every rank must call it)
        if self.val_loader is not None and is_eval_epoch:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # checkpointing and early stopping are decided on evaluation epochs only
        stop = False
        if self.master_process and is_eval_epoch:
            if current_best < best_val_loss:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop = True

        # share the stop decision so non-master ranks do not block in the next collective
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop), device=self.device)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            stop = bool(stop_flag.item())
        if stop:
            break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss


def _apply_accumulated_gradients(self, scaler: "torch.GradScaler") -> None:
    """Unscale, clip and apply accumulated gradients, then advance the warmup LR schedule.

    Parameters
    ----------
    scaler : torch.GradScaler
        Scaler used for the backward passes whose gradients are being applied.
    """
    # unscale first so clipping sees the true gradient magnitudes
    scaler.unscale_(self.optimizer)
    torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
    if self.conditional_model is not None:
        torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

    scaler.step(self.optimizer)
    scaler.update()
    self.optimizer.zero_grad()

    # the warmup scheduler advances once per optimizer step
    self.warmup_lr_scheduler.step()
783
+
784
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
    """Encode a batch of conditioning inputs with the conditional model.

    Each item is converted to ``str``. Tensor inputs are first moved to the CPU and
    turned into Python values. The resulting strings are tokenized, padded or
    truncated to exactly ``self.max_token_length`` tokens, and passed to
    ``self.conditional_model`` together with the tokenizer's ``attention_mask``.

    Parameters
    ----------
    y : torch.Tensor or list
        Conditioning inputs for the batch (text prompts or label values).

    Returns
    -------
    torch.Tensor
        Output of ``self.conditional_model`` for the tokenized batch.
    """
    if isinstance(y, torch.Tensor):
        raw_items = y.cpu().numpy().tolist()
    else:
        raw_items = y
    prompts = list(map(str, raw_items))

    batch = self.tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt",
    ).to(self.device)

    return self.conditional_model(batch["input_ids"], batch["attention_mask"])
816
+
817
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
    """Save a training checkpoint (only called by the master process).

    The checkpoint is written to ``<store_path>/ddpm_epoch_<epoch><suffix>.pth``. It
    holds the epoch, the loss, the optimizer state, ``max_epochs``, the noise
    predictor and conditional model weights, and the variance scheduler (its
    ``state_dict`` when it is an ``nn.Module``, otherwise the object itself).

    Model weights come from the underlying modules with the DDP (``.module``) and
    ``torch.compile`` (``._orig_mod``) wrappers removed. Saving the compiled wrapper
    would prefix every key with ``_orig_mod.`` and the weights would not load into a
    plain model.

    Any error is reported with ``print`` and not re-raised, so a failed save does not
    abort training.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    loss : float
        Current loss value.
    suffix : str, optional
        Suffix to add to checkpoint filename.
    """
    def _unwrap(model: nn.Module) -> nn.Module:
        # DDP keeps the wrapped model under `.module`
        if self.use_ddp:
            model = model.module
        # torch.compile's wrapper keeps the original module under `._orig_mod`
        return getattr(model, "_orig_mod", model)

    try:
        noise_predictor_state = _unwrap(self.noise_predictor).state_dict()
        conditional_state = (
            _unwrap(self.conditional_model).state_dict()
            if self.conditional_model is not None
            else None
        )
        variance_scheduler = self.forward_diffusion.variance_scheduler

        checkpoint = {
            'epoch': epoch,
            'model_state_dict_noise_predictor': noise_predictor_state,
            'model_state_dict_conditional': conditional_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'variance_scheduler_model': (
                variance_scheduler.state_dict() if isinstance(variance_scheduler, nn.Module)
                else variance_scheduler
            ),
            'max_epochs': self.max_epochs,
        }

        os.makedirs(self.store_path, exist_ok=True)
        filepath = os.path.join(self.store_path, f"ddpm_epoch_{epoch}{suffix}.pth")
        torch.save(checkpoint, filepath)

        print(f"Model saved at epoch {epoch}")

    except Exception as e:
        # best effort: a failed save is reported but must not stop training
        print(f"Failed to save model: {e}")
864
+
865
+
866
+
867
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Evaluate the noise predictor on the validation loader.

    For every validation batch the noise-prediction loss (``self.objective``) is
    computed at random timesteps. When ``self.metrics_`` and
    ``self.reverse_diffusion`` are both set, a full reverse-diffusion chain is also
    run from fresh noise. The result is clamped to ``image_output_range``
    (optionally rescaled to [0, 1] together with the originals) and scored with
    ``self.metrics_.forward``. Under DDP only the validation loss is averaged
    across ranks. Training mode is restored before returning.

    Returns
    -------
    tuple
        ``(val_loss, fid, mse, psnr, ssim, lpips_score)``. ``fid`` is ``inf`` and the
        other metrics are ``None`` when they were not computed.
    """
    trainable_beta = self.forward_diffusion.variance_scheduler.trainable_beta

    # switch every participating module to inference mode
    self.noise_predictor.eval()
    if self.conditional_model is not None:
        self.conditional_model.eval()
    if trainable_beta:
        self.forward_diffusion.eval()
        self.reverse_diffusion.eval()

    losses = []
    collected = {"fid": [], "mse": [], "psnr": [], "ssim": [], "lpips": []}

    with torch.no_grad():
        for images, labels in self.val_loader:
            images = images.to(self.device)
            reference = images.clone()
            context = None
            if self.conditional_model is not None:
                context = self._process_conditional_input(labels)

            num_steps = self.forward_diffusion.variance_scheduler.num_steps

            # noise-prediction loss at random timesteps
            noise = torch.randn_like(images).to(self.device)
            steps = torch.randint(0, num_steps, (images.shape[0],)).to(self.device)
            prediction = self.noise_predictor(self.forward_diffusion(images, noise, steps), steps, context, None)
            losses.append(self.objective(prediction, noise).item())

            if self.metrics_ is None or self.reverse_diffusion is None:
                continue

            # run the full reverse chain from pure noise
            xt = torch.randn_like(images).to(self.device)
            for step in range(num_steps - 1, -1, -1):
                time_steps = torch.full((xt.shape[0],), step, device=self.device, dtype=torch.long)
                xt = self.reverse_diffusion(xt, self.noise_predictor(xt, time_steps, context, None), time_steps)

            low, high = self.image_output_range[0], self.image_output_range[1]
            generated = torch.clamp(xt, min=low, max=high)
            if self.normalize_output:
                span = high - low
                generated = (generated - low) / span
                reference = (reference - low) / span

            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(reference, generated)
            if getattr(self.metrics_, 'fid', False):
                collected["fid"].append(fid)
            if getattr(self.metrics_, 'metrics', False):
                collected["mse"].append(mse)
                collected["psnr"].append(psnr)
                collected["ssim"].append(ssim)
            if getattr(self.metrics_, 'lpips', False):
                collected["lpips"].append(lpips_score)

    val_loss = torch.tensor(losses).mean().item()

    # only the loss is synchronised across DDP ranks
    if self.use_ddp:
        reduced = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(reduced, op=dist.ReduceOp.AVG)
        val_loss = reduced.item()

    def _mean_or(values, fallback):
        return torch.tensor(values).mean().item() if values else fallback

    fid_avg = _mean_or(collected["fid"], float('inf'))
    mse_avg = _mean_or(collected["mse"], None)
    psnr_avg = _mean_or(collected["psnr"], None)
    ssim_avg = _mean_or(collected["ssim"], None)
    lpips_avg = _mean_or(collected["lpips"], None)

    # back to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
961
+
962
+
963
+ ###==================================================================================================================###
964
+
965
+
966
+ class SampleDDPM(nn.Module):
967
+ """mage generation using a trained Denoising Diffusion Probabilistic Model (DDPM).
968
+
969
+ Implements the sampling process for DDPM, generating images by iteratively
970
+ denoising random noise using a trained noise predictor and reverse diffusion
971
+ process. Supports conditional generation with text prompts via a conditional
972
+ model, as inspired by Ho et al. (2020).
973
+
974
+ Parameters
975
+ ----------
976
+ reverse_diffusion : nn.Module
977
+ Reverse diffusion module (e.g., ReverseDDPM) for the reverse process.
978
+ noise_predictor : nn.Module
979
+ Trained model to predict noise at each time step.
980
+ image_shape : tuple
981
+ Tuple of (height, width) specifying the generated image dimensions.
982
+ conditional_model : nn.Module, optional
983
+ Model for conditional generation (e.g., text embeddings), default None.
984
+ tokenizer : str, optional
985
+ Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
986
+ max_token_length : int, optional
987
+ Maximum length for tokenized prompts (default: 77).
988
+ batch_size : int, optional
989
+ Number of images to generate per batch (default: 1).
990
+ in_channels : int, optional
991
+ Number of input channels for generated images (default: 3).
992
+ device : torch.device, optional
993
+ Device for computation (default: CUDA if available, else CPU).
994
+ image_output_range : tuple, optional
995
+ Tuple of (min, max) for clamping generated images (default: (-1, 1)).
996
+ """
997
+ def __init__(
998
+ self,
999
+ reverse_diffusion: torch.nn.Module,
1000
+ noise_predictor: torch.nn.Module,
1001
+ image_shape: Tuple[int, int],
1002
+ conditional_model: Optional[torch.nn.Module] = None,
1003
+ tokenizer: str = "bert-base-uncased",
1004
+ max_token_length: int = 77,
1005
+ batch_size: int = 1,
1006
+ in_channels: int = 3,
1007
+ device: Optional[str] = None,
1008
+ image_output_range: Tuple[float, float] = (-1.0, 1.0)
1009
+ ) -> None:
1010
+ super().__init__()
1011
+ if device is None:
1012
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1013
+ elif isinstance(device, str):
1014
+ self.device = torch.device(device)
1015
+ else:
1016
+ self.device = device
1017
+ self.reverse = reverse_diffusion.to(self.device)
1018
+ self.noise_predictor = noise_predictor.to(self.device)
1019
+ self.conditional_model = conditional_model.to(self.device) if conditional_model else None
1020
+ self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
1021
+ self.max_token_length = max_token_length
1022
+ self.in_channels = in_channels
1023
+ self.image_shape = image_shape
1024
+ self.batch_size = batch_size
1025
+ self.image_output_range = image_output_range
1026
+
1027
+ if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
1028
+ isinstance(s, int) and s > 0 for s in image_shape):
1029
+ raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
1030
+ if batch_size <= 0:
1031
+ raise ValueError("batch_size must be positive")
1032
+ if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
1033
+ raise ValueError("output_range must be a tuple (min, max) with min < max")
1034
+
1035
+ def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
1036
+ """Tokenizes text prompts for conditional generation.
1037
+
1038
+ Converts input prompts into tokenized input IDs and attention masks using the
1039
+ specified tokenizer, suitable for use with the conditional model.
1040
+
1041
+ Parameters
1042
+ ----------
1043
+ prompts : str or list
1044
+ A single text prompt or a list of text prompts.
1045
+
1046
+ Returns
1047
+ -------
1048
+ input_ids : torch.Tensor
1049
+ Tokenized input IDs, shape (batch_size, max_length).
1050
+ attention_mask : torch.Tensor
1051
+ Attention mask, shape (batch_size, max_length).
1052
+ """
1053
+ if isinstance(prompts, str):
1054
+ prompts = [prompts]
1055
+ elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
1056
+ raise TypeError("prompts must be a string or list of strings")
1057
+ encoded = self.tokenizer(
1058
+ prompts,
1059
+ padding="max_length",
1060
+ truncation=True,
1061
+ max_length=self.max_token_length,
1062
+ return_tensors="pt"
1063
+ )
1064
+ return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
1065
+
1066
+ def forward(
1067
+ self,
1068
+ conditions: Optional[Union[str, List]] = None,
1069
+ normalize_output: bool = True,
1070
+ save_images: bool = True,
1071
+ save_path: str = "ddpm_generated"
1072
+ ) -> torch.Tensor:
1073
+ """Generates images using the DDPM sampling process.
1074
+
1075
+ Iteratively denoises random noise to generate images using the reverse diffusion
1076
+ process and noise predictor. Supports conditional generation with text prompts.
1077
+ Optionally saves generated images to a specified directory.
1078
+
1079
+ Parameters
1080
+ ----------
1081
+ conditions : str or list, optional
1082
+ Text prompt(s) for conditional generation, default None.
1083
+ normalize_output : bool, optional
1084
+ If True, normalizes output images to [0, 1] (default: True).
1085
+ save_images : bool, optional
1086
+ If True, saves generated images to `save_path` (default: True).
1087
+ save_path : str, optional
1088
+ Directory to save generated images (default: "ddpm_generated").
1089
+
1090
+ Returns
1091
+ -------
1092
+ generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
1093
+ """
1094
+ if conditions is not None and self.conditional_model is None:
1095
+ raise ValueError("Conditions provided but no conditional model specified")
1096
+ if conditions is None and self.conditional_model is not None:
1097
+ raise ValueError("Conditions must be provided for conditional model")
1098
+
1099
+ noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(
1100
+ self.device)
1101
+
1102
+ self.noise_predictor.eval()
1103
+ self.reverse.eval()
1104
+ if self.conditional_model:
1105
+ self.conditional_model.eval()
1106
+
1107
+ with torch.no_grad():
1108
+ xt = noisy_samples
1109
+ for t in reversed(range(self.reverse.variance_scheduler.num_steps)):
1110
+ time_steps = torch.full((self.batch_size,), t, device=self.device)#, dtype=torch.long)
1111
+ if self.conditional_model is not None and conditions is not None:
1112
+ input_ids, attention_masks = self.tokenize(conditions)
1113
+ key_padding_mask = (attention_masks == 0)
1114
+ y = self.conditional_model(input_ids, key_padding_mask)
1115
+ predicted_noise = self.noise_predictor(xt, time_steps, y, None)
1116
+ else:
1117
+ predicted_noise = self.noise_predictor(xt, time_steps, None, None)
1118
+ xt = self.reverse(xt, predicted_noise, time_steps)
1119
+
1120
+ generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
1121
+ if normalize_output:
1122
+ generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
1123
+
1124
+ # save images if save_images is True
1125
+ if save_images:
1126
+ os.makedirs(save_path, exist_ok=True) # create directory if it doesn't exist
1127
+ for i in range(generated_imgs.size(0)):
1128
+ img_path = os.path.join(save_path, f"image_{i+1}.png")
1129
+ save_image(generated_imgs[i], img_path)
1130
+
1131
+ return generated_imgs
1132
+
1133
+ def to(self, device: torch.device) -> Self:
1134
+ """Moves the module and its components to the specified device.
1135
+
1136
+ Updates the device attribute and moves the reverse diffusion, noise predictor,
1137
+ and conditional model (if present) to the specified device.
1138
+
1139
+ Parameters
1140
+ ----------
1141
+ device : torch.device
1142
+ Target device for the module and its components.
1143
+
1144
+ Returns
1145
+ -------
1146
+ sample_ddpm (SampleDDPM) - moved to the specified device.
1147
+ """
1148
+ self.device = device
1149
+ self.noise_predictor.to(device)
1150
+ self.reverse.to(device)
1151
+ if self.conditional_model:
1152
+ self.conditional_model.to(device)
1153
+ return super().to(device)