TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/ddim_model.py ADDED
@@ -0,0 +1,1296 @@
1
+
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ import torch.distributed as dist
7
+ from torch.nn.parallel import DistributedDataParallel as DDP
8
+ from torch.distributed import init_process_group, destroy_process_group
9
+ from tqdm import tqdm
10
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
11
+ from transformers import BertTokenizer
12
+ import warnings
13
+ from torchvision.utils import save_image
14
+ from typing import Optional, Tuple, Callable, List, Any, Union, Self
15
+ import os
16
+
17
+
18
+
19
+
20
+
21
class ForwardDDIM(nn.Module):
    """Forward (noising) process of DDIM.

    Given clean data and Gaussian noise, produces the noisy sample at the
    requested time steps using the cumulative schedule held by the variance
    scheduler (Song et al., 2021, "Denoising Diffusion Implicit Models").

    Parameters
    ----------
    `variance_scheduler` : object
        Scheduler exposing `num_steps`, `trainable_beta`, `sqrt_alpha_cumprod`,
        `sqrt_one_minus_alpha_cumprod` and `compute_schedule`.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Noise `x0` to the given time steps.

        Parameters
        ----------
        `x0` : torch.Tensor
            Clean input; the per-sample coefficients are reshaped to
            (-1, 1, 1, 1), so a 4-D (batch, channels, height, width) layout
            is expected.
        `noise` : torch.Tensor
            Noise tensor combined with `x0` (same shape as `x0`).
        `time_steps` : torch.Tensor
            Integer step indices, each in [0, num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - `sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` is outside [0, num_steps - 1].
        """
        steps_valid = (time_steps >= 0) & (time_steps < self.variance_scheduler.num_steps)
        if not torch.all(steps_valid):
            raise ValueError(f"time_steps must be between 0 and {self.variance_scheduler.num_steps - 1}")

        target_device = x0.device
        if self.variance_scheduler.trainable_beta:
            # Trainable schedules are recomputed from the current parameters on every call.
            _, _, _, signal_scale, noise_scale = self.variance_scheduler.compute_schedule(time_steps)
        else:
            # Fixed schedules are read from the precomputed buffers.
            signal_scale = self.variance_scheduler.sqrt_alpha_cumprod[time_steps]
            noise_scale = self.variance_scheduler.sqrt_one_minus_alpha_cumprod[time_steps]

        # One coefficient per sample, broadcast over the remaining three dimensions.
        signal_scale = signal_scale.to(target_device).view(-1, 1, 1, 1)
        noise_scale = noise_scale.to(target_device).view(-1, 1, 1, 1)

        return signal_scale * x0 + noise_scale * noise
79
+
80
+
81
+ ###==================================================================================================================###
82
+
83
+
84
class ReverseDDIM(nn.Module):
    """Reverse (denoising) process of DDIM.

    Performs one DDIM update from the sample at `time_steps` to the sample at
    `prev_time_steps` on the subsampled (tau) schedule, following
    Song et al. (2021), Eq. 12 and Eq. 16.

    Parameters
    ----------
    `variance_scheduler` : object
        Scheduler exposing `tau_num_steps`, `eta` and `get_tau_schedule`.
    """

    def __init__(self, variance_scheduler: torch.nn.Module):
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, xt: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor, prev_time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply one reverse DDIM step.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input at `time_steps`; the per-sample coefficients are reshaped
            to (-1, 1, 1, 1), so a 4-D (batch, channels, height, width) layout is expected.
        `predicted_noise` : torch.Tensor
            Noise estimate for `xt` (same shape as `xt`).
        `time_steps` : torch.Tensor
            Integer indices into the tau schedule, each in [0, tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Integer indices of the target step, each in [0, tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Sample at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimate of the clean data, same shape as `xt`.

        Raises
        ------
        ValueError
            If any index is outside [0, tau_num_steps - 1].
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.variance_scheduler.tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.variance_scheduler.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < self.variance_scheduler.tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {self.variance_scheduler.tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.variance_scheduler.get_tau_schedule()

        def _gather(table: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
            # One coefficient per sample, broadcast over the remaining three dimensions.
            return table[steps].to(xt.device).view(-1, 1, 1, 1)

        sqrt_ac_t = _gather(tau_sqrt_alpha_cumprod, time_steps)
        sqrt_1m_ac_t = _gather(tau_sqrt_one_minus_alpha_cumprod, time_steps)
        sqrt_ac_prev = _gather(tau_sqrt_alpha_cumprod, prev_time_steps)
        sqrt_1m_ac_prev = _gather(tau_sqrt_one_minus_alpha_cumprod, prev_time_steps)

        eta = self.variance_scheduler.eta
        x0 = (xt - sqrt_1m_ac_t * predicted_noise) / sqrt_ac_t

        # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)   (DDIM Eq. 16)
        # The previous expression collapsed to eta * sqrt(1 - a_prev) / sqrt(a_prev), which is
        # larger than sqrt(1 - a_prev) and forced the direction term below to be clamped.
        step_ratio = torch.clamp(1 - (sqrt_ac_t / sqrt_ac_prev) ** 2, min=0.0)
        noise_coeff = eta * (sqrt_1m_ac_prev / torch.clamp(sqrt_1m_ac_t, min=1e-8)) * step_ratio.sqrt()

        # Remaining variance budget goes to the deterministic "direction pointing to xt" term.
        direction_coeff = torch.clamp(sqrt_1m_ac_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = sqrt_ac_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0
148
+
149
+
150
+ ###==================================================================================================================###
151
+
152
class VarianceSchedulerDDIM(nn.Module):
    """Variance scheduler for the DDIM noise schedule.

    Holds the beta schedule, the quantities derived from it (alphas,
    cumulative products and their square roots) and the indices of the
    subsampled (tau) steps used for DDIM sampling. The schedule is either
    fixed (stored in buffers) or trainable (stored as an unconstrained
    parameter mapped into [beta_start, beta_end] with a sigmoid).

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the DDIM reverse process (default: 0, deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for DDIM sampling (default: 100).
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Raises
    ------
    ValueError
        If the beta range, `num_steps`, `tau_num_steps` or `beta_method` is invalid.
    """

    def __init__(
            self,
            eta: Optional[float] = None,
            num_steps: int = 1000,
            tau_num_steps: int = 100,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear"
    ):
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)

        if trainable_beta:
            # Reparameterize: store unconstrained values whose sigmoid maps back into
            # [beta_start, beta_end]. Most schedules touch the range end points exactly,
            # where logit is -inf/+inf, so the input is clamped away from 0 and 1 via `eps`.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=1e-6))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced indices into the full schedule, truncated to integers.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying the sigmoid reparameterization if trainable."""
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule for the chosen method.

        Parameters
        ----------
        `beta_range` : tuple
            (min_beta, max_beta); the result is clamped to this range.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        beta (torch.Tensor) - Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not a supported schedule name.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            raw = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = raw.max() - raw.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (raw - raw.min()) / span
            else:
                # A single step has no spread to rescale (0 / 0 would give NaN);
                # fall back to the lower bound, matching what "linear" yields for one step.
                beta = torch.full((num_steps,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns the schedule restricted to the subsampled (tau) steps.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values for subsampled steps, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values for subsampled steps, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative product of alphas for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            # No cached buffers exist in trainable mode; derive everything from the current parameters.
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule()
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derives the schedule quantities from the current betas.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, every returned tensor is indexed by these steps.
            If None, the full schedule is returned.

        Returns
        -------
        betas : torch.Tensor
            Beta values.
        alphas : torch.Tensor
            1 - betas.
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas.
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod.
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod).
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)

        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
346
+
347
+
348
+ ###==================================================================================================================###
349
+
350
+
351
+ class TrainDDIM(nn.Module):
352
+ """Trainer for Denoising Diffusion Implicit Models (DDIM).
353
+
354
+ Manages the training process for DDIM, optimizing a noise predictor model to learn
355
+ the noise added by the forward diffusion process. Supports conditional training with
356
+ text prompts, mixed precision training, learning rate scheduling, early stopping, and
357
+ checkpointing, as inspired by Song et al. (2021).
358
+
359
+ Parameters
360
+ ----------
361
+ `noise_predictor` : nn.Module
362
+ Model to predict noise added during the forward diffusion process.
363
+ `variance_scheduler` : nn.Module
364
+ Variance-scheduler module (e.g., VarianceSchedulerDDIM) defining the noise schedule.
365
+ `data_loader` : torch.utils.data.DataLoader
366
+ DataLoader for training data.
367
+ `optimizer` : torch.optim.Optimizer
368
+ Optimizer for training the noise predictor and conditional model (if applicable).
369
+ `objective` : callable
370
+ Loss function to compute the difference between predicted and actual noise.
371
+ `val_loader` : torch.utils.data.DataLoader, optional
372
+ DataLoader for validation data, default None.
373
+ `max_epoch` : int, optional
374
+ Maximum number of training epochs (default: 1000).
375
+ `device` : torch.device, optional
376
+ Device for computation (default: CUDA if available, else CPU).
377
+ `conditional_model` : nn.Module, optional
378
+ Model for conditional generation (e.g., text embeddings), default None.
379
+ `metrics_` : object, optional
380
+ Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
381
+ `tokenizer` : BertTokenizer, optional
382
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
383
+ `max_length` : int, optional
384
+ Maximum length for tokenized prompts (default: 77).
385
+ `store_path` : str, optional
386
+ Path to save model checkpoints (default: "ddim_model.pth").
387
+ `patience` : int, optional
388
+ Number of epochs to wait for improvement before early stopping (default: 100).
389
+ `warmup_epochs` : int, optional
390
+ Number of epochs for learning rate warmup (default: 100).
391
+ `val_frequency` : int, optional
392
+ Frequency (in epochs) for validation (default: 10).
393
+ `output_range` : tuple, optional
394
+ Range for clamping generated images (default: (-1, 1)).
395
+ `normalize_output` : bool, optional
396
+ Whether to normalize generated images to [0, 1] for metrics (default: True).
397
+ `ddp` : bool, optional
398
+ Whether to use Distributed Data Parallel training (default: False).
399
+ `num_grad_accumulation` : int, optional
400
+ Number of gradient accumulation steps before optimizer update (default: 1).
401
+ `progress_frequency` : int, optional
402
+ Number of epochs before printing loss.
403
+ compilation : bool, optional
404
+ whether the model is internally compiled using torch.compile (default: false)
405
+ """
406
def __init__(
        self,
        noise_predictor: torch.nn.Module,
        variance_scheduler: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epoch: int = 1000,
        device: str = None,
        conditional_model: torch.nn.Module = None,
        metrics_: Optional[Any] = None,
        tokenizer: Optional[BertTokenizer] = None,
        max_length: int = 77,
        store_path: Optional[str] = None,
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        output_range: Tuple[float, float] = (-1, 1),
        normalize_output: bool = True,
        ddp: bool = False,
        num_grad_accumulation: int = 1,
        progress_frequency: int = 1,
        compilation: bool = False
) -> None:
    """Stores the training configuration and prepares models, schedulers and tokenizer.

    Resolves the compute device (distributed setup overrides it), moves the
    models onto it, builds the forward/reverse diffusion modules from the
    variance scheduler, creates the plateau and warmup learning-rate
    schedulers, and sets `self.tokenizer` to the supplied tokenizer or to the
    default "bert-base-uncased" tokenizer.

    Raises
    ------
    ValueError
        If no tokenizer is supplied and the default one cannot be loaded.
    """
    super().__init__()
    # Initialize DDP settings first
    self.ddp = ddp
    self.num_grad_accumulation = num_grad_accumulation
    self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup distributed training if enabled (this replaces self.device with the local CUDA device)
    if self.ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # Move models to appropriate device
    self.noise_predictor = noise_predictor.to(self.device)
    self.variance_scheduler = variance_scheduler.to(self.device)
    self.forward_diffusion = ForwardDDIM(variance_scheduler=self.variance_scheduler).to(self.device)
    self.reverse_diffusion = ReverseDDIM(variance_scheduler=self.variance_scheduler).to(self.device)
    # Compare against None explicitly: container modules define __len__, so an
    # empty one is falsy and would otherwise be discarded.
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # Training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ddim_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epoch = max_epoch
    self.max_length = max_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.output_range = output_range
    self.normalize_output = normalize_output
    self.progress_frequency = progress_frequency
    self.compilation = compilation

    # Learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # Initialize tokenizer: keep the caller's tokenizer when one is given,
    # otherwise fall back to the default BERT tokenizer.
    if tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        self.tokenizer = tokenizer
480
+
481
+ def _setup_ddp(self) -> None:
482
+ """Setup Distributed Data Parallel training configuration.
483
+
484
+ Initializes process group, determines rank information, and sets up
485
+ CUDA device for the current process.
486
+ """
487
+ # Check if DDP environment variables are set
488
+ if "RANK" not in os.environ:
489
+ raise ValueError("DDP enabled but RANK environment variable not set")
490
+ if "LOCAL_RANK" not in os.environ:
491
+ raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
492
+ if "WORLD_SIZE" not in os.environ:
493
+ raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
494
+
495
+ # Ensure CUDA is available for DDP
496
+ if not torch.cuda.is_available():
497
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
498
+
499
+ # Initialize process group only if not already initialized
500
+ if not torch.distributed.is_initialized():
501
+ init_process_group(backend="nccl")
502
+
503
+ # Get rank information
504
+ self.ddp_rank = int(os.environ["RANK"]) # Global rank across all nodes
505
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # Local rank on current node
506
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # Total number of processes
507
+
508
+ # Set device and make it current
509
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
510
+ # self.device = f"cuda:{self.ddp_local_rank}"
511
+ torch.cuda.set_device(self.device)
512
+
513
+ # Master process handles logging, checkpointing, etc.
514
+ self.master_process = self.ddp_rank == 0
515
+
516
+ if self.master_process:
517
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
518
+
519
+ def _setup_single_gpu(self) -> None:
520
+ """Setup single GPU or CPU training configuration."""
521
+ self.ddp_rank = 0
522
+ self.ddp_local_rank = 0
523
+ self.ddp_world_size = 1
524
+ self.master_process = True
525
+
526
+ def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
527
+ """Loads a training checkpoint to resume training.
528
+
529
+ Restores the state of the noise predictor, conditional model (if applicable),
530
+ and optimizer from a saved checkpoint. Handles DDP model state dict loading.
531
+
532
+ Parameters
533
+ ----------
534
+ checkpoint_path : str
535
+ Path to the checkpoint file.
536
+
537
+ Returns
538
+ -------
539
+ epoch : int
540
+ The epoch at which the checkpoint was saved.
541
+ loss : float
542
+ The loss at the checkpoint.
543
+ """
544
+ try:
545
+ # Load checkpoint with proper device mapping
546
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
547
+ except FileNotFoundError:
548
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
549
+
550
+ # Load noise predictor state
551
+ if 'model_state_dict_noise_predictor' not in checkpoint:
552
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
553
+
554
+ # Handle DDP wrapped model state dict
555
+ state_dict = checkpoint['model_state_dict_noise_predictor']
556
+ if self.ddp and not any(key.startswith('module.') for key in state_dict.keys()):
557
+ # If loading non-DDP checkpoint into DDP model, add 'module.' prefix
558
+ state_dict = {f'module.{k}': v for k, v in state_dict.items()}
559
+ elif not self.ddp and any(key.startswith('module.') for key in state_dict.keys()):
560
+ # If loading DDP checkpoint into non-DDP model, remove 'module.' prefix
561
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
562
+
563
+ self.noise_predictor.load_state_dict(state_dict)
564
+
565
+ # Load conditional model state if applicable
566
+ if self.conditional_model is not None:
567
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
568
+ cond_state_dict = checkpoint['model_state_dict_conditional']
569
+ # Handle DDP wrapping for conditional model
570
+ if self.ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()):
571
+ cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()}
572
+ elif not self.ddp and any(key.startswith('module.') for key in cond_state_dict.keys()):
573
+ cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()}
574
+ self.conditional_model.load_state_dict(cond_state_dict)
575
+ else:
576
+ warnings.warn(
577
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
578
+ "skipping conditional model loading"
579
+ )
580
+
581
+ # Load hyper_params state
582
+ if 'variance_scheduler_model' not in checkpoint:
583
+ raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
584
+ try:
585
+ if isinstance(self.variance_scheduler, nn.Module):
586
+ self.variance_scheduler.load_state_dict(checkpoint['variance_scheduler_model'])
587
+ else:
588
+ self.variance_scheduler = checkpoint['variance_scheduler_model']
589
+ except Exception as e:
590
+ warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")
591
+
592
+ # Load optimizer state
593
+ if 'optimizer_state_dict' not in checkpoint:
594
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
595
+ try:
596
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
597
+ except ValueError as e:
598
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
599
+
600
+ epoch = checkpoint.get('epoch', -1)
601
+ loss = checkpoint.get('loss', float('inf'))
602
+
603
+ if self.master_process:
604
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
605
+ return epoch, loss
606
+
607
+ @staticmethod
608
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
609
+ """Creates a learning rate scheduler for warmup.
610
+
611
+ Generates a scheduler that linearly increases the learning rate from 0 to the
612
+ optimizer's initial value over the specified warmup epochs, then maintains it.
613
+
614
+ Parameters
615
+ ----------
616
+ `optimizer` : torch.optim.Optimizer
617
+ Optimizer to apply the scheduler to.
618
+ `warmup_epochs` : int
619
+ Number of epochs for the warmup phase.
620
+
621
+ Returns
622
+ -------
623
+ lr_scheduler (torch.optim.lr_scheduler.LambdaLR) - Learning rate scheduler for warmup.
624
+ """
625
+ def lr_lambda(epoch: int) -> float:
626
+ if epoch < warmup_epochs:
627
+ return epoch / warmup_epochs
628
+ return 1.0
629
+
630
+ return LambdaLR(optimizer, lr_lambda)
631
+
632
+ def _wrap_models_for_ddp(self) -> None:
633
+ """Wrap models with DistributedDataParallel for multi-GPU training."""
634
+ if self.ddp:
635
+ # Wrap noise predictor with DDP
636
+ self.noise_predictor = DDP(
637
+ self.noise_predictor,
638
+ device_ids=[self.ddp_local_rank],
639
+ find_unused_parameters=True
640
+ )
641
+
642
+ # Wrap conditional model with DDP if it exists
643
+ if self.conditional_model is not None:
644
+ self.conditional_model = DDP(
645
+ self.conditional_model,
646
+ device_ids=[self.ddp_local_rank],
647
+ find_unused_parameters=True
648
+ )
649
+
650
def forward(self) -> Tuple[List, float]:
    """Trains the DDIM model to predict noise added by the forward diffusion process.

    Runs up to `max_epoch` epochs. Every step samples Gaussian noise and random
    timesteps, applies `forward_diffusion`, and evaluates
    `objective(predicted_noise, noise)` under autocast with a gradient scaler.
    Gradients are accumulated over `num_grad_accumulation` steps, clipped to a
    norm of 1.0 and applied, after which the warmup scheduler is stepped. At the
    end of each epoch `self.scheduler` is stepped with the validation loss (on
    validation epochs) or the mean training loss, and the master process takes
    care of checkpointing and early stopping.

    Returns
    -------
    train_losses : list of float
        List of mean training losses per epoch.
    best_val_loss : float
        Best validation or training loss achieved (only tracked on the master
        process; other ranks keep the initial `float("inf")`).
    """

    # Set models to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()

    # Compile models for optimization (if supported)
    if self.compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # Wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # Resolve the autocast device type once. Comparing `self.device` directly
    # against the string 'cuda' is False for torch.device objects and for
    # strings such as "cuda:0", which would silently select CPU autocast.
    autocast_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

    # Initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # Main training loop
    for epoch in range(self.max_epoch):
        # Set epoch for distributed sampler if using DDP
        if self.ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # Training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # Process conditional inputs if conditional model exists
            if self.conditional_model is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            # Forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):

                # Generate noise and timesteps
                noise = torch.randn_like(x).to(self.device)
                t = torch.randint(0, self.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)
                assert x.device == noise.device == t.device, "Device mismatch detected"
                assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"

                # Apply forward diffusion
                noisy_x = self.forward_diffusion(x, noise, t)

                # Predict noise
                p_noise = self.noise_predictor(x=noisy_x, t=t, y=y_encoded)

                # Compute loss and scale for gradient accumulation
                loss = self.objective(p_noise, noise) / self.num_grad_accumulation

            # Backward pass
            scaler.scale(loss).backward()

            # Gradient accumulation and optimizer step
            if (step + 1) % self.num_grad_accumulation == 0:
                # Clip gradients (after unscaling so the threshold applies to true gradients)
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
                if self.conditional_model is not None:
                    torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

                # Optimizer step
                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # Update learning rate (warmup scheduler)
                self.warmup_lr_scheduler.step()

            # Record loss (unscaled)
            train_losses_epoch.append(loss.item() * self.num_grad_accumulation)

        # Compute mean training loss
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

        # All-reduce loss across processes for DDP
        if self.ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        train_losses.append(mean_train_loss)

        # Print training progress (only master process)
        if self.master_process:
            if (epoch + 1) % self.progress_frequency == 0:
                print(f"\nEpoch: {epoch + 1} | Learning Rate: {self.optimizer.param_groups[0]['lr']} | Train Loss: {mean_train_loss:.4f}", end="")

        # Validation step
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # Save checkpoint and early stopping (decided by the master process)
        stop_training = False
        if self.master_process:
            if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop_training = True

        # Share the stop decision with every rank. If only the master process
        # left the loop, the remaining ranks would block forever in the next
        # collective operation.
        if self.ddp:
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    # Clean up DDP
    if self.ddp:
        destroy_process_group()

    return train_losses, best_val_loss
804
+
805
+ def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
806
+ """Process conditional input for text-to-image generation.
807
+
808
+ Parameters
809
+ ----------
810
+ y : torch.Tensor or list
811
+ Conditional input (text prompts).
812
+
813
+ Returns
814
+ -------
815
+ torch.Tensor
816
+ Encoded conditional input.
817
+ """
818
+ # Convert to string list
819
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
820
+ y_list = [str(item) for item in y_list]
821
+
822
+ # Tokenize
823
+ y_encoded = self.tokenizer(
824
+ y_list,
825
+ padding="max_length",
826
+ truncation=True,
827
+ max_length=self.max_length,
828
+ return_tensors="pt"
829
+ ).to(self.device)
830
+
831
+ # Get embeddings
832
+ input_ids = y_encoded["input_ids"]
833
+ attention_mask = y_encoded["attention_mask"]
834
+ y_encoded = self.conditional_model(input_ids, attention_mask)
835
+
836
+ return y_encoded
837
+ def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
838
+ """Save model checkpoint (only called by master process).
839
+
840
+ Parameters
841
+ ----------
842
+ epoch : int
843
+ Current epoch number.
844
+ loss : float
845
+ Current loss value.
846
+ suffix : str, optional
847
+ Suffix to add to checkpoint filename.
848
+ """
849
+ try:
850
+ # Get state dicts, handling DDP wrapping
851
+ noise_predictor_state = (
852
+ self.noise_predictor.module.state_dict() if self.ddp
853
+ else self.noise_predictor.state_dict()
854
+ )
855
+ conditional_state = None
856
+ if self.conditional_model is not None:
857
+ conditional_state = (
858
+ self.conditional_model.module.state_dict() if self.ddp
859
+ else self.conditional_model.state_dict()
860
+ )
861
+
862
+ checkpoint = {
863
+ 'epoch': epoch,
864
+ 'model_state_dict_noise_predictor': noise_predictor_state,
865
+ 'model_state_dict_conditional': conditional_state,
866
+ 'optimizer_state_dict': self.optimizer.state_dict(),
867
+ 'loss': loss,
868
+ 'variance_scheduler_model': (
869
+ self.variance_scheduler.state_dict() if isinstance(self.variance_scheduler, nn.Module)
870
+ else self.variance_scheduler
871
+ ),
872
+ 'max_epoch': self.max_epoch,
873
+ }
874
+
875
+ save_path = self.store_path + suffix + ".pth" if suffix else self.store_path + ".pth"
876
+ torch.save(checkpoint, save_path)
877
+ print(f"Model saved at epoch {epoch} to {save_path}")
878
+
879
+ except Exception as e:
880
+ print(f"Failed to save model: {e}")
881
+
882
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Evaluate the noise predictor on the validation loader.

    For every validation batch the noise-prediction loss is computed at random
    timesteps. When both `metrics_` and `reverse_diffusion` are set, a batch of
    samples is additionally generated by iterating the reverse process over
    `variance_scheduler.tau_num_steps` steps, clamped to `output_range`
    (optionally rescaled to [0, 1]) and scored against the input batch. The
    models are put in eval mode for the duration and returned to train mode
    before the method exits.

    Returns
    -------
    val_loss : float
        Mean validation loss (averaged across processes under DDP).
    fid : float, or `float('inf')` if not computed
        Mean FID score.
    mse : float, or None if not computed
        Mean MSE
    psnr : float, or None if not computed
        Mean PSNR
    ssim : float, or None if not computed
        Mean SSIM
    lpips_score : float, or None if not computed
        Mean LPIPS score
    """

    def _mean_or(values, fallback):
        # Average the collected scores, or use the fallback when none exist.
        return torch.tensor(values).mean().item() if values else fallback

    self.noise_predictor.eval()
    if self.conditional_model is not None:
        self.conditional_model.eval()

    batch_losses = []
    fid_hist, mse_hist, psnr_hist, ssim_hist, lpips_hist = [], [], [], [], []

    with torch.no_grad():
        for images, labels in self.val_loader:
            images = images.to(self.device)
            reference = images.clone()

            # Encode the conditioning signal when a conditional model exists
            condition = None
            if self.conditional_model is not None:
                condition = self._process_conditional_input(labels)

            # Noise-prediction loss at random timesteps
            target_noise = torch.randn_like(images).to(self.device)
            steps = torch.randint(0, self.variance_scheduler.num_steps, (images.shape[0],)).to(self.device)

            noised = self.forward_diffusion(images, target_noise, steps)
            estimate = self.noise_predictor(noised, steps, condition)
            batch_losses.append(self.objective(estimate, target_noise).item())

            # Image-domain metrics need both a metrics object and a sampler
            if self.metrics_ is None or self.reverse_diffusion is None:
                continue

            sample = torch.randn_like(images).to(self.device)

            # Reverse diffusion sampling, from the last subsampled step down to 0
            for tau in reversed(range(self.variance_scheduler.tau_num_steps)):
                current = torch.full((sample.shape[0],), tau, device=self.device, dtype=torch.long)
                previous = torch.full((sample.shape[0],), max(tau - 1, 0), device=self.device, dtype=torch.long)
                eps = self.noise_predictor(sample, current, condition)
                sample, _ = self.reverse_diffusion(sample, eps, current, previous)

            # Clamp and (optionally) rescale both generated and reference images
            low, high = self.output_range[0], self.output_range[1]
            generated = torch.clamp(sample, min=low, max=high)
            if self.normalize_output:
                generated = (generated - self.output_range[0]) / (self.output_range[1] - self.output_range[0])
                reference = (reference - self.output_range[0]) / (self.output_range[1] - self.output_range[0])

            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(reference, generated)

            # Only keep the scores whose computation is enabled on the metrics object
            if getattr(self.metrics_, 'fid', False):
                fid_hist.append(fid)
            if getattr(self.metrics_, 'metrics', False):
                mse_hist.append(mse)
                psnr_hist.append(psnr)
                ssim_hist.append(ssim)
            if getattr(self.metrics_, 'lpips', False):
                lpips_hist.append(lpips_score)

    val_loss = torch.tensor(batch_losses).mean().item()

    # Average the validation loss across processes for DDP
    if self.ddp:
        synced = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(synced, op=dist.ReduceOp.AVG)
        val_loss = synced.item()

    fid_avg = _mean_or(fid_hist, float('inf'))
    mse_avg = _mean_or(mse_hist, None)
    psnr_avg = _mean_or(psnr_hist, None)
    ssim_avg = _mean_or(ssim_hist, None)
    lpips_avg = _mean_or(lpips_hist, None)

    # Return to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
984
+
985
+ ###==================================================================================================================###
986
+
987
+ class SampleDDIM(nn.Module):
988
+ """Image generation using a trained DDIM model.
989
+
990
+ Implements the sampling process for DDIM, generating images by iteratively denoising
991
+ random noise using a trained noise predictor and reverse diffusion process with a
992
+ subsampled time step schedule. Supports conditional generation with text prompts,
993
+ as inspired by Song et al. (2021).
994
+
995
+ Parameters
996
+ ----------
997
+ `reverse_diffusion` : nn.Module
998
+ Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
999
+ `noise_predictor` : nn.Module
1000
+ Trained model to predict noise at each time step.
1001
+ `image_shape` : tuple
1002
+ Tuple of (height, width) specifying the generated image dimensions.
1003
+ `conditional_model` : nn.Module, optional
1004
+ Model for conditional generation (e.g., text embeddings), default None.
1005
+ `tokenizer` : str, optional
1006
+ Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
1007
+ `max_length` : int, optional
1008
+ Maximum length for tokenized prompts (default: 77).
1009
+ `batch_size` : int, optional
1010
+ Number of images to generate per batch (default: 1).
1011
+ `in_channels` : int, optional
1012
+ Number of input channels for generated images (default: 3).
1013
+ `device` : torch.device, optional
1014
+ Device for computation (default: CUDA if available, else CPU).
1015
+ `output_range` : tuple, optional
1016
+ Tuple of (min, max) for clamping generated images (default: (-1, 1)).
1017
+ """
1018
+ def __init__(
1019
+ self,
1020
+ reverse_diffusion: torch.nn.Module,
1021
+ noise_predictor: torch.nn.Module,
1022
+ image_shape: Tuple[int, int],
1023
+ conditional_model: Optional[torch.nn.Module] = None,
1024
+ tokenizer: str = "bert-base-uncased",
1025
+ max_length: int = 77,
1026
+ batch_size: int = 1,
1027
+ in_channels: int = 3,
1028
+ device: Optional[str] = None,
1029
+ output_range: Tuple[float, float] = (-1.0, 1.0)
1030
+ ) -> None:
1031
+ super().__init__()
1032
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
1033
+ self.reverse = reverse_diffusion.to(self.device)
1034
+ self.noise_predictor = noise_predictor.to(self.device)
1035
+ self.conditional_model = conditional_model.to(self.device) if conditional_model else None
1036
+ self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
1037
+ self.max_length = max_length
1038
+ self.in_channels = in_channels
1039
+ self.image_shape = image_shape
1040
+ self.batch_size = batch_size
1041
+ self.output_range = output_range
1042
+
1043
+ if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
1044
+ isinstance(s, int) and s > 0 for s in image_shape):
1045
+ raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
1046
+ if batch_size <= 0:
1047
+ raise ValueError("batch_size must be positive")
1048
+ if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
1049
+ raise ValueError("output_range must be a tuple (min, max) with min < max")
1050
+
1051
+
1052
+ def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
1053
+ """Tokenizes text prompts for conditional generation.
1054
+
1055
+ Converts input prompts into tokenized input IDs and attention masks using the
1056
+ specified tokenizer, suitable for use with the conditional model.
1057
+
1058
+ Parameters
1059
+ ----------
1060
+ `prompts` : str or list
1061
+ A single text prompt or a list of text prompts.
1062
+
1063
+ Returns
1064
+ -------
1065
+ input_ids : torch.Tensor
1066
+ Tokenized input IDs, shape (batch_size, max_length).
1067
+ attention_mask : torch.Tensor
1068
+ Attention mask, shape (batch_size, max_length).
1069
+ """
1070
+ if isinstance(prompts, str):
1071
+ prompts = [prompts]
1072
+ elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
1073
+ raise TypeError("prompts must be a string or list of strings")
1074
+ encoded = self.tokenizer(
1075
+ prompts,
1076
+ padding="max_length",
1077
+ truncation=True,
1078
+ max_length=self.max_length,
1079
+ return_tensors="pt"
1080
+ )
1081
+ return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
1082
+
1083
+ def forward(self, conditions: Optional[Union[str, List]] = None, normalize_output: bool = True, save_images: bool = True, save_path: str = "ddim_generated") -> torch.Tensor:
1084
+ """Generates images using the DDIM sampling process.
1085
+
1086
+ Iteratively denoises random noise to generate images using the reverse diffusion
1087
+ process with a subsampled time step schedule and noise predictor. Supports
1088
+ conditional generation with text prompts.
1089
+
1090
+ Parameters
1091
+ ----------
1092
+ `conditions` : str or list, optional
1093
+ Text prompt(s) for conditional generation, default None.
1094
+ `normalize_output` : bool, optional
1095
+ If True, normalizes output images to [0, 1] (default: True).
1096
+ `save_images` : bool, optional
1097
+ If True, saves generated images to `save_path` (default: True).
1098
+ `save_path` : str, optional
1099
+ Directory to save generated images (default: "ddim_generated").
1100
+
1101
+ Returns
1102
+ -------
1103
+ generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
1104
+ """
1105
+
1106
+ if conditions is not None and self.conditional_model is None:
1107
+ raise ValueError("Conditions provided but no conditional model specified")
1108
+ if conditions is None and self.conditional_model is not None:
1109
+ raise ValueError("Conditions must be provided for conditional model")
1110
+
1111
+ noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)
1112
+
1113
+ self.noise_predictor.eval()
1114
+ self.reverse.eval()
1115
+ if self.conditional_model:
1116
+ self.conditional_model.eval()
1117
+
1118
+ with torch.no_grad():
1119
+ xt = noisy_samples
1120
+ for t in reversed(range(self.reverse.hyper_params.tau_num_steps)):
1121
+ time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
1122
+ prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long)
1123
+
1124
+ if self.conditional_model is not None and conditions is not None:
1125
+ input_ids, attention_masks = self.tokenize(conditions)
1126
+ key_padding_mask = (attention_masks == 0)
1127
+ y = self.conditional_model(input_ids, key_padding_mask)
1128
+ predicted_noise = self.noise_predictor(xt, time_steps, y)
1129
+ else:
1130
+ predicted_noise = self.noise_predictor(xt, time_steps)
1131
+
1132
+ xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
1133
+
1134
+ generated_imgs = torch.clamp(xt, min=self.output_range[0], max=self.output_range[1])
1135
+ if normalize_output:
1136
+ generated_imgs = (generated_imgs - self.output_range[0]) / (self.output_range[1] - self.output_range[0])
1137
+
1138
+ if save_images:
1139
+ os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
1140
+ for i in range(generated_imgs.size(0)):
1141
+ img_path = os.path.join(save_path, f"image_{i}.png")
1142
+ save_image(generated_imgs[i], img_path)
1143
+
1144
+ return generated_imgs
1145
+
1146
+ def to(self, device: torch.device) -> Self:
1147
+ """Moves the module and its components to the specified device.
1148
+
1149
+ Updates the device attribute and moves the reverse diffusion, noise predictor,
1150
+ and conditional model (if present) to the specified device.
1151
+
1152
+ Parameters
1153
+ ----------
1154
+ `device` : torch.device
1155
+ Target device for the module and its components.
1156
+
1157
+ Returns
1158
+ -------
1159
+ sample_ddim (SampleDDIM) - moved to the specified device.
1160
+ """
1161
+ self.device = device
1162
+ self.noise_predictor.to(device)
1163
+ self.reverse.to(device)
1164
+ if self.conditional_model:
1165
+ self.conditional_model.to(device)
1166
+ return super().to(device)
1167
+
1168
+
1169
+ """
1170
+ from utils import NoisePredictor, Metrics, TextEncoder
1171
+ import os
1172
+ import matplotlib.pyplot as plt
1173
+ from PIL import Image
1174
+ import numpy as np
1175
+ import torch
1176
+ import torch.nn as nn
1177
+ import torchvision
1178
+ from torchvision import datasets, transforms
1179
+ from torch.utils.data import DataLoader, Subset
1180
+
1181
+
1182
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1183
+ transform = transforms.Compose([
1184
+ transforms.ToTensor(),
1185
+ transforms.Normalize((0.5,), (0.5,))
1186
+ ])
1187
+
1188
+ # Load the FashionMNIST training dataset (28x28 grayscale images)
1189
+ train_dataset = datasets.FashionMNIST(
1190
+ root='./data',
1191
+ train=True,
1192
+ download=True,
1193
+ transform=transform
1194
+ )
1195
+
1196
+ # Load the FashionMNIST test dataset (28x28 grayscale images)
1197
+ test_dataset = datasets.FashionMNIST(
1198
+ root='./data',
1199
+ train=False,
1200
+ download=True,
1201
+ transform=transform
1202
+ )
1203
+
1204
+ # Define subset sizes for training (100 samples) and validation (10 samples)
1205
+ train_subset_indices = torch.randperm(len(train_dataset))[:100]
1206
+ test_subset_indices = torch.randperm(len(test_dataset))[:10]
1207
+
1208
+ # Create subsets from the FashionMNIST training and test datasets
1209
+ train_subset = Subset(train_dataset, train_subset_indices)
1210
+ test_subset = Subset(test_dataset, test_subset_indices)
1211
+
1212
+ # Create DataLoaders for the training and validation subsets
1213
+ train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
1214
+ val_loader = DataLoader(test_subset, batch_size=10, shuffle=False, drop_last=False)
1215
+
1216
+ # Initialize the NoisePredictor for the DDIM model with parameters for grayscale images
1217
+ noise_predictor = NoisePredictor(
1218
+ in_channels=1, # Single channel for grayscale images in the training data
1219
+ down_channels=[16, 32],
1220
+ mid_channels=[32, 32],
1221
+ up_channels=[32, 16],
1222
+ down_sampling=[True, True],
1223
+ time_embed_dim=32,
1224
+ y_embed_dim=32,
1225
+ num_down_blocks=2,
1226
+ num_mid_blocks=2,
1227
+ num_up_blocks=2,
1228
+ down_sampling_factor=2
1229
+ ).to(device)
1230
+
1231
+ # label conditional model
1232
+ text_encoder = TextEncoder(
1233
+ use_pretrained_model=True,
1234
+ model_name="bert-base-uncased",
1235
+ vocabulary_size=30522,
1236
+ num_layers=2,
1237
+ input_dimension=32,
1238
+ output_dimension=32,
1239
+ num_heads=2,
1240
+ context_length=77
1241
+ ).to(device)
1242
+
1243
+ # Set up the AdamW optimizer over the trainable NoisePredictor and TextEncoder parameters with a learning rate of 1e-3
1244
+ optimizer = torch.optim.AdamW(
1245
+ [p for p in noise_predictor.parameters() if p.requires_grad] +
1246
+ [p for p in text_encoder.parameters() if p.requires_grad], lr=1e-3
1247
+ )
1248
+
1249
+ # Initialize the Mean Squared Error (MSE) loss function
1250
+ loss = nn.MSELoss()
1251
+
1252
+ # Configure the Metrics class for evaluation on CPU (GPUs are recommended for actual training)
1253
+ metrics = Metrics(
1254
+ device="cpu", # Using CPU for this tutorial, but GPUs are recommended for training diffusion models
1255
+ fid=True,
1256
+ metrics=True,
1257
+ lpips_=True
1258
+ )
1259
+
1260
+
1261
+ # Initialize DDIM hyperparameters for the noise schedule
1262
+ hyperparams_ddim = HyperParamsDDIM(
1263
+ num_steps=500,
1264
+ beta_start=1e-4,
1265
+ beta_end=0.02,
1266
+ trainable_beta=False,
1267
+ beta_method="linear"
1268
+ )
1269
+
1270
+ # Set up the reverse diffusion process for sampling
1271
+ reverse_ddim = ReverseDDIM(hyperparams_ddim)
1272
+
1273
+ # Configure the DDIM trainer for model training
1274
+ train_ddim = TrainDDIM(
1275
+ noise_predictor=noise_predictor,
1276
+ hyper_params=hyperparams_ddim,
1277
+ conditional_model=text_encoder,
1278
+ metrics_=metrics,
1279
+ optimizer=optimizer,
1280
+ objective=loss,
1281
+ data_loader=train_loader,
1282
+ val_loader=val_loader,
1283
+ max_epoch=5,
1284
+ device="cuda",
1285
+ store_path="test_ddim",
1286
+ val_frequency=3,
1287
+ ddp=False,
1288
+ num_grad_accumulation=3,
1289
+ progress_frequency=3,
1290
+ compilation=True
1291
+ )
1292
+
1293
+ train_losses, best_val_loss = train_ddim()
1294
+ """
1295
+
1296
+