TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/sde.py ADDED
@@ -0,0 +1,1231 @@
1
+ """
2
+ **Score-Based Generative Modeling with Stochastic Differential Equations (SDE)**
3
+
4
+ This module implements a complete framework for score-based generative models using SDEs,
5
+ as described in Song et al. (2021, "Score-Based Generative Modeling through Stochastic
6
+ Differential Equations"). It provides components for forward and reverse diffusion
7
+ processes, hyperparameter management, training, and image sampling, supporting Variance
8
+ Exploding (VE), Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE
9
+ methods for flexible noise schedules. Supports both unconditional and conditional
10
+ generation with text prompts.
11
+
12
+ **Components**
13
+
14
+ - **ForwardSDE**: Forward diffusion process to add noise using SDE methods.
15
+ - **ReverseSDE**: Reverse diffusion process to denoise using SDE methods.
16
+ - **VarianceSchedulerSDE**: Noise schedule and SDE-specific parameter management.
17
+ - **TrainSDE**: Training loop with mixed precision and scheduling.
18
+ - **SampleSDE**: Image generation from trained SDE models.
19
+
20
+ **References**
21
+
22
+ - Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." arXiv preprint arXiv:2011.13456 (2020).
23
+
24
+ ---------------------------------------------------------------------------------
25
+ """
26
+
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
31
+ from torch.nn.parallel import DistributedDataParallel as DDP
32
+ from torch.distributed import init_process_group, destroy_process_group
33
+ import torch.distributed as dist
34
+ from typing import Optional, Tuple, Callable, List, Any, Union, Self
35
+ from tqdm import tqdm
36
+ from torch.optim.lr_scheduler import LambdaLR
37
+ from transformers import BertTokenizer
38
+ import warnings
39
+ from torchvision.utils import save_image
40
+ import os
41
+
42
+
43
+ ###==================================================================================================================###
44
+
45
class ForwardSDE(nn.Module):
    """Forward diffusion process for SDE-based generative models.

    Perturbs clean data with one discretized step of the selected SDE:
    Variance Exploding ("ve"), Variance Preserving ("vp"), sub-Variance
    Preserving ("sub-vp") or the deterministic "ode" variant, following
    Song et al. (2021).

    Parameters
    ----------
    variance_scheduler : object
        Schedule object (e.g. VarianceSchedulerSDE). This class reads its
        `dt`, `sigmas`, `betas` and `_cum_betas` attributes.
    sde_method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".
    """
    def __init__(self, variance_scheduler: torch.nn.Module, sde_method: str) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler
        self.sde_method = sde_method

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward SDE diffusion process to the input data.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor, shape (batch_size, channels, height, width).
        noise : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If `sde_method` is not one of the supported methods.
        """
        dt = self.variance_scheduler.dt
        if self.sde_method == "ve":
            # read the schedule once (the property may recompute it on every access)
            sigmas = self.variance_scheduler.sigmas
            sigma_t = sigmas[time_steps]
            # sigma_{t-1} is taken as 0 for samples at t == 0. This is decided
            # per sample: a batch-wide check would zero sigma_{t-1} for every
            # sample as soon as one of them sits at t == 0.
            prev_steps = torch.clamp(time_steps - 1, min=0)
            sigma_t_prev = torch.where(time_steps > 0, sigmas[prev_steps], torch.zeros_like(sigma_t))
            sigma_diff = torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_t_prev ** 2, min=0))
            x0 = x0 + noise * sigma_diff.view(-1, 1, 1, 1)

        elif self.sde_method == "vp":
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * dt) * noise
            x0 = x0 + drift + diffusion

        elif self.sde_method == "sub-vp":
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            cum_betas = self.variance_scheduler._cum_betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
            x0 = x0 + drift + diffusion

        elif self.sde_method == "ode":
            # deterministic variant: drift only, `noise` is ignored
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            x0 = x0 + drift
        else:
            raise ValueError(f"Unknown method: {self.sde_method}")
        return x0
118
+
119
+ ###==================================================================================================================###
120
+
121
class ReverseSDE(nn.Module):
    """Reverse diffusion process for SDE-based generative models.

    Performs one discretized reverse step of the selected SDE: Variance
    Exploding ("ve"), Variance Preserving ("vp"), sub-Variance Preserving
    ("sub-vp") or the deterministic "ode" variant, following Song et al.
    (2021), using a network's noise prediction.

    Parameters
    ----------
    variance_scheduler : object
        Schedule object (e.g. VarianceSchedulerSDE). This class reads its
        `dt`, `sigmas`, `betas` and `_cum_betas` attributes.
    sde_method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".
    """
    def __init__(self, variance_scheduler: torch.nn.Module, sde_method: str) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler
        self.sde_method = sde_method

    def forward(self, xt: torch.Tensor, noise: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the reverse SDE diffusion process to the noisy input.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        noise : torch.Tensor or None
            Gaussian noise tensor, same shape as `xt`. If None, the diffusion
            term is dropped. Ignored by the "ode" method.
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Updated tensor after one reverse step, same shape as `xt`.

        Raises
        ------
        ValueError
            If `sde_method` is not one of the supported methods.

        **Notes**

        - For the "ve" and "ode" methods, the output is clamped to [-1e5, 1e5].
        """
        dt = self.variance_scheduler.dt
        if self.sde_method == "ve":
            # read the schedule once (the property may recompute it on every access)
            sigmas = self.variance_scheduler.sigmas
            sigma_t = sigmas[time_steps]
            # sigma_{t-1} is taken as 0 for samples at t == 0. This is decided
            # per sample: a batch-wide check would zero sigma_{t-1} for every
            # sample as soon as one of them sits at t == 0.
            prev_steps = torch.clamp(time_steps - 1, min=0)
            sigma_t_prev = torch.where(time_steps > 0, sigmas[prev_steps], torch.zeros_like(sigma_t))
            sigma_sq_diff = sigma_t ** 2 - sigma_t_prev ** 2
            sigma_diff = torch.sqrt(torch.clamp(sigma_sq_diff, min=0))
            drift = -sigma_sq_diff.view(-1, 1, 1, 1) * predicted_noise * dt
            diffusion = sigma_diff.view(-1, 1, 1, 1) * noise if noise is not None else 0
            xt = xt + drift + diffusion
            xt = torch.clamp(xt, -1e5, 1e5)

        elif self.sde_method == "vp":
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.sde_method == "sub-vp":
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            cum_betas = self.variance_scheduler._cum_betas[time_steps].view(-1, 1, 1, 1)
            # shared factor of the sub-VP drift and diffusion terms
            decay = 1 - torch.exp(-2 * cum_betas)
            drift = -0.5 * betas * xt * dt - betas * decay * predicted_noise * dt
            diffusion = torch.sqrt(betas * decay * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.sde_method == "ode":
            # deterministic variant: drift only, `noise` is ignored
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
            xt = xt + drift
            xt = torch.clamp(xt, -1e5, 1e5)
        else:
            raise ValueError(f"Unknown method: {self.sde_method}")
        return xt
203
+
204
+ ###==================================================================================================================###
205
+
206
class VarianceSchedulerSDE(nn.Module):
    """Hyperparameters for SDE-based generative models.

    Holds the beta schedule (fixed buffer or trainable parameter), the
    geometric sigma schedule used by the VE method, the time grid and the
    step size `dt`, following Song et al. (2021).

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Starting value for beta schedule (default: 1e-4).
    beta_end : float, optional
        Ending value for beta schedule (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    sigma_start : float, optional
        Starting value for sigma schedule for VE method (default: 1e-3).
    sigma_end : float, optional
        Ending value for sigma schedule for VE method (default: 10.0).
    start : float, optional
        Start of the time interval for SDE integration (default: 0.0).
    end : float, optional
        End of the time interval for SDE integration (default: 1.0).

    Raises
    ------
    ValueError
        If the beta or sigma bounds do not satisfy 0 < start < end, if
        `num_steps` is not positive, or if `beta_method` is unknown.
    """
    def __init__(
            self,
            num_steps: int = 1000,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear",
            sigma_start: float = 1e-3,
            sigma_end: float = 10.0,
            start: float = 0.0,
            end: float = 1.0
    ) -> None:
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)
        # plain attribute (not a buffer), so it does NOT follow .to(device)
        self.time = torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32)
        self.dt = (self.end - self.start) / self.num_steps

        if trainable_beta:
            # Reparameterize: an unconstrained parameter is mapped back into
            # [beta_start, beta_end] by a sigmoid in the `betas` property.
            # `eps` keeps the schedule endpoints (normalized to exactly 0 and 1)
            # from becoming -inf / +inf parameters.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=1e-6))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
            self.register_buffer("sigmas_buffer", self.sigma_start * (self.sigma_end / self.sigma_start) ** self.time)

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable."""
        if self.trainable_beta:
            # transform unconstrained parameters to valid beta range using sigmoid
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        else:
            return self._buffers['betas_buffer']

    @property
    def _cum_betas(self) -> torch.Tensor:
        """Returns the cumulative beta values, computing dynamically if trainable."""
        if self.trainable_beta:
            return torch.cumsum(self.betas, dim=0) * self.dt
        else:
            return self._buffers['cum_betas']

    @property
    def sigmas(self) -> torch.Tensor:
        """Returns the sigma values, computing dynamically if trainable."""
        if self.trainable_beta:
            # `self.time` is not a registered buffer; move it next to the
            # parameters so indexing with device-resident time steps works.
            time = self.time.to(self.beta_raw.device)
            return self.sigma_start * (self.sigma_end / self.sigma_start) ** time
        else:
            return self._buffers['sigmas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        betas (torch.Tensor) - Tensor of beta values, shape (num_steps,), clamped to `beta_range`.

        Raises
        ------
        ValueError
            If `method` is not a supported schedule name.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            spread = beta.max() - beta.min()
            if spread > 0:
                # rescale the 1/t curve onto [beta_min, beta_max]
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / spread
            else:
                # single-step schedule: rescaling would divide 0 by 0
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_variance(self, time_steps: torch.Tensor, method: str) -> torch.Tensor:
        """Computes the variance for the specified SDE method at given time steps.

        Parameters
        ----------
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, num_steps - 1].
        method : str
            SDE method to compute variance for. Supported methods: "ve", "vp", "sub-vp".

        Returns
        -------
        variance_values (torch.Tensor) - Variance values for the specified time steps, shape (batch_size,).

        Raises
        ------
        ValueError
            If `method` is not one of "ve", "vp", "sub-vp".
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        elif method == "vp":
            # go through the property: the `cum_betas` buffer only exists
            # when the schedule is not trainable
            return 1 - torch.exp(-self._cum_betas[time_steps])
        elif method == "sub-vp":
            return 1 - torch.exp(-2 * self._cum_betas[time_steps])
        else:
            raise ValueError(f"Unknown method: {method}")
370
+
371
+ ###==================================================================================================================###
372
+
373
+ class TrainSDE(nn.Module):
374
+ """Trainer for score-based generative models using Stochastic Differential Equations.
375
+
376
+ Manages the training process for SDE-based generative models, optimizing a noise
377
+ predictor to learn the noise added by the forward SDE process, as described in Song
378
+ et al. (2021). Supports conditional training with text prompts, mixed precision,
379
+ learning rate scheduling, early stopping, and checkpointing.
380
+
381
+ Parameters
382
+ ----------
383
+
384
+ noise_predictor : nn.Module
385
+ Model to predict noise added during the forward SDE process.
386
+ forward_diffusion : nn.Module
387
+ Forward SDE diffusion module for adding noise.
388
+ reverse_diffusion: nn.Module
389
+ Reverse SDE diffusion module for denoising.
390
+ data_loader : torch.utils.data.DataLoader
391
+ DataLoader for training data.
392
+ optimizer : torch.optim.Optimizer
393
+ Optimizer for training the noise predictor and conditional model (if applicable).
394
+ objective : callable
395
+ Loss function to compute the difference between predicted and actual noise.
396
+ val_loader : torch.utils.data.DataLoader, optional
397
+ DataLoader for validation data, default None.
398
+ max_epochs : int, optional
399
+ Maximum number of training epochs (default: 1000).
400
+ device : torch.device, optional
401
+ Device for computation (default: CUDA if available, else CPU).
402
+ conditional_model : nn.Module, optional
403
+ Model for conditional generation (e.g., text embeddings), default None.
404
+ metrics_ : object, optional
405
+ Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
406
+ bert_tokenizer : BertTokenizer, optional
407
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
408
+ max_token_length : int, optional
409
+ Maximum length for tokenized prompts (default: 77).
410
+ store_path : str, optional
411
+ Path to save model checkpoints (default: "sde_model").
412
+ patience : int, optional
413
+ Number of epochs to wait for improvement before early stopping (default: 100).
414
+ warmup_epochs : int, optional
415
+ Number of epochs for learning rate warmup (default: 100).
416
+ val_frequency : int, optional
417
+ Frequency (in epochs) for validation (default: 10).
418
+ image_output_range : tuple, optional
419
+ Range for clamping generated images (default: (-1, 1)).
420
+ normalize_output : bool, optional
421
+ Whether to normalize generated images to [0, 1] for metrics (default: True).
422
+ use_ddp : bool, optional
423
+ Whether to use Distributed Data Parallel training (default: False).
424
+ grad_accumulation_steps : int, optional
425
+ Number of gradient accumulation steps before optimizer update (default: 1).
426
+ log_frequency : int, optional
427
+ Number of epochs before printing loss.
428
+ use_compilation : bool, optional
429
+ whether the model is internally compiled using torch.compile (default: false)
430
+ """
431
def __init__(
        self,
        noise_predictor: torch.nn.Module,
        forward_diffusion: torch.nn.Module,
        reverse_diffusion: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        conditional_model: Optional[torch.nn.Module] = None,
        metrics_: Optional[Any] = None,
        bert_tokenizer: Optional[BertTokenizer] = None,
        max_token_length: int = 77,
        store_path: Optional[str] = None,
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_output: bool = True,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False
) -> None:
    """Stores the training configuration and prepares devices, schedulers and tokenizer.

    Resolves the compute device (CUDA if available when `device` is None),
    configures DDP or single-process rank attributes, moves the models to
    the device, builds a ReduceLROnPlateau scheduler plus a warmup
    scheduler, and loads the default "bert-base-uncased" tokenizer when
    `bert_tokenizer` is not supplied.

    Raises
    ------
    ValueError
        If the default tokenizer cannot be loaded (the original error is
        attached as the cause), or if DDP is requested without the
        required environment variables.
    RuntimeError
        If DDP is requested but CUDA is not available.
    """
    super().__init__()
    # initialize DDP settings first
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # setup distributed training if enabled (this may override self.device)
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to appropriate device
    self.noise_predictor = noise_predictor.to(self.device)
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    # explicit None check: container modules define __len__, so an empty
    # one is falsy and would otherwise be silently discarded
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "sde_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # initialize tokenizer
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            # keep the underlying failure attached as the cause
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        self.tokenizer = bert_tokenizer
513
+
514
+
515
+ def _setup_ddp(self) -> None:
516
+ """Setup Distributed Data Parallel training configuration.
517
+
518
+ Initializes process group, determines rank information, and sets up
519
+ CUDA device for the current process.
520
+ """
521
+ # check if DDP environment variables are set
522
+ if "RANK" not in os.environ:
523
+ raise ValueError("DDP enabled but RANK environment variable not set")
524
+ if "LOCAL_RANK" not in os.environ:
525
+ raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
526
+ if "WORLD_SIZE" not in os.environ:
527
+ raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
528
+
529
+ # ensure CUDA is available for DDP
530
+ if not torch.cuda.is_available():
531
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
532
+
533
+ # initialize process group only if not already initialized
534
+ if not torch.distributed.is_initialized():
535
+ init_process_group(backend="nccl")
536
+
537
+ # get rank information
538
+ self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
539
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
540
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
541
+
542
+ # set device and make it current
543
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
544
+ torch.cuda.set_device(self.device)
545
+
546
+ # master process handles logging, checkpointing, etc.
547
+ self.master_process = self.ddp_rank == 0
548
+
549
+ if self.master_process:
550
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
551
+
552
+ def _setup_single_gpu(self) -> None:
553
+ """Setup single GPU or CPU training configuration."""
554
+ self.ddp_rank = 0
555
+ self.ddp_local_rank = 0
556
+ self.ddp_world_size = 1
557
+ self.master_process = True
558
+
559
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a training checkpoint to resume training.

    Restores the noise predictor, the conditional model (if one is
    configured), the variance scheduler and the optimizer from a saved
    checkpoint. A leading DDP 'module.' prefix on the checkpoint keys is
    added or removed as needed to match the keys the target model expects.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch stored in the checkpoint (-1 if absent).
    loss : float
        The loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If `checkpoint_path` does not exist.
    KeyError
        If the checkpoint lacks 'model_state_dict_noise_predictor',
        'variance_scheduler_model' or 'optimizer_state_dict'.
    """
    def _align_ddp_prefix(state_dict, model):
        # Return `state_dict` with the leading 'module.' prefix adjusted to
        # whatever `model` expects. Only the *leading* prefix is touched, so
        # submodules that are themselves named 'module' keep their keys.
        expected = set(model.state_dict().keys())
        if set(state_dict) == expected:
            return state_dict
        stripped = {k.removeprefix('module.'): v for k, v in state_dict.items()}
        if set(stripped) == expected:
            return stripped
        prefixed = {f'module.{k}': v for k, v in state_dict.items()}
        if set(prefixed) == expected:
            return prefixed
        # no variant matches: let load_state_dict report the mismatch
        return state_dict

    try:
        # NOTE(security): torch.load deserializes with pickle unless restricted
        # via `weights_only`; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

    # load noise predictor state
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    state_dict = _align_ddp_prefix(checkpoint['model_state_dict_noise_predictor'], self.noise_predictor)
    self.noise_predictor.load_state_dict(state_dict)

    # load conditional model state if applicable
    if self.conditional_model is not None:
        if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
            cond_state_dict = _align_ddp_prefix(checkpoint['model_state_dict_conditional'], self.conditional_model)
            self.conditional_model.load_state_dict(cond_state_dict)
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # load variance_scheduler state
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    try:
        scheduler_state = checkpoint['variance_scheduler_model']
        # handle each diffusion module on its own: restore in place when its
        # scheduler is an nn.Module, otherwise replace it with the stored object
        for diffusion in (self.forward_diffusion, self.reverse_diffusion):
            if isinstance(diffusion.variance_scheduler, nn.Module):
                diffusion.variance_scheduler.load_state_dict(scheduler_state)
            else:
                diffusion.variance_scheduler = scheduler_state
    except Exception as e:
        # best effort: a scheduler mismatch must not abort resuming
        warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # load optimizer state
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
    return epoch, loss
642
+
643
+ @staticmethod
644
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
645
+ """Creates a learning rate scheduler for warmup.
646
+
647
+ Generates a scheduler that linearly increases the learning rate from 0 to the
648
+ optimizer's initial value over the specified warmup epochs, then maintains it.
649
+
650
+ Parameters
651
+ ----------
652
+ optimizer : torch.optim.Optimizer
653
+ Optimizer to apply the scheduler to.
654
+ warmup_epochs : int
655
+ Number of epochs for the warmup phase.
656
+
657
+ Returns
658
+ -------
659
+ torch.optim.lr_scheduler.LambdaLR
660
+ Learning rate scheduler for warmup.
661
+ """
662
+
663
+ def lr_lambda(epoch):
664
+ if epoch < warmup_epochs:
665
+ return epoch / warmup_epochs
666
+ return 1.0
667
+
668
+ return LambdaLR(optimizer, lr_lambda)
669
+
670
+ def _wrap_models_for_ddp(self) -> None:
671
+ """Wrap models with DistributedDataParallel for multi-GPU training."""
672
+ if self.use_ddp:
673
+ # wrap noise predictor with DDP
674
+ self.noise_predictor = DDP(
675
+ self.noise_predictor,
676
+ device_ids=[self.ddp_local_rank],
677
+ find_unused_parameters=True
678
+ )
679
+
680
+ # wrap conditional model with DDP if it exists
681
+ if self.conditional_model is not None:
682
+ self.conditional_model = DDP(
683
+ self.conditional_model,
684
+ device_ids=[self.ddp_local_rank],
685
+ find_unused_parameters=True
686
+ )
687
+
688
+
689
def forward(self) -> Tuple[List, float]:
    """Train the noise predictor (and the conditional model, if any).

    Runs the full training loop: forward diffusion of each batch at random
    timesteps, noise prediction under autocast, gradient accumulation with
    gradient-norm clipping, warmup and plateau learning-rate scheduling,
    optional validation, checkpointing on improvement, and early stopping.

    Returns
    -------
    train_losses : list of float
        Mean training loss of each completed epoch (local to this process).
    best_val_loss : float
        Best monitored loss seen so far: the validation loss on epochs where
        validation ran, the mean training loss otherwise.

    **Notes**

    - The autocast backend is derived from the *type* of ``self.device``, so
      ``torch.device('cuda')`` and ``'cuda:1'`` select CUDA autocast just
      like the plain string ``'cuda'``.
    - Improvement is only recorded on epochs where
      ``(epoch + 1) % val_frequency == 0``; every other epoch counts toward
      ``patience``.
    - Best-loss tracking and the early-stopping decision run on every
      process, using the all-reduced losses, so all DDP ranks leave the loop
      on the same epoch. Only the master process prints and writes
      checkpoints.
    - NOTE(review): when the number of batches is not a multiple of
      ``grad_accumulation_steps`` the trailing gradients are not stepped at
      epoch end and carry over into the next epoch.
    """
    # set models to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (if supported); best effort only
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # Comparing a torch.device (or 'cuda:0') with the string 'cuda' is
    # False, which would silently select CPU autocast on a GPU run; compare
    # the device type instead.
    autocast_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        # training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # process conditional inputs if conditional model exists
            if self.conditional_model is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                # generate noise and timesteps
                noise = torch.randn_like(x)
                t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

                # apply forward diffusion
                noisy_x = self.forward_diffusion(x, noise, t)

                # predict noise
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)

                # compute loss and scale for gradient accumulation
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # backward pass
            scaler.scale(loss).backward()

            # gradient accumulation and optimizer step
            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients (must be unscaled first)
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
                if self.conditional_model is not None:
                    torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

                # optimizer step
                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # update learning rate (warmup scheduler)
                self.warmup_lr_scheduler.step()

            # record loss (unscaled)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # compute mean training loss
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
        train_losses.append(mean_train_loss)

        # all-reduce loss across processes for DDP
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        # validation step
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # Checkpointing and early stopping. The monitored loss has been
        # all-reduced, so every rank takes the same branch here; keeping the
        # decision on all ranks prevents non-master processes from training
        # on (and hanging in collectives) after the master has stopped.
        if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
            best_val_loss = current_best
            wait = 0
            if self.master_process:
                self._save_checkpoint(epoch + 1, best_val_loss)
        else:
            wait += 1
            if wait >= self.patience:
                if self.master_process:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
850
+
851
+ def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
852
+ """Process conditional input for text-to-image generation.
853
+
854
+ Parameters
855
+ ----------
856
+ y : torch.Tensor or list
857
+ Conditional input (text prompts).
858
+
859
+ Returns
860
+ -------
861
+ torch.Tensor
862
+ Encoded conditional input.
863
+ """
864
+ # convert to string list
865
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
866
+ y_list = [str(item) for item in y_list]
867
+
868
+ # tokenize
869
+ y_encoded = self.tokenizer(
870
+ y_list,
871
+ padding="max_length",
872
+ truncation=True,
873
+ max_length=self.max_token_length,
874
+ return_tensors="pt"
875
+ ).to(self.device)
876
+
877
+ # get embeddings
878
+ input_ids = y_encoded["input_ids"]
879
+ attention_mask = y_encoded["attention_mask"]
880
+ y_encoded = self.conditional_model(input_ids, attention_mask)
881
+
882
+ return y_encoded
883
+
884
+
885
+ def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
886
+ """Save model checkpoint (only called by master process).
887
+
888
+ Parameters
889
+ ----------
890
+ epoch : int
891
+ Current epoch number.
892
+ loss : float
893
+ Current loss value.
894
+ suffix : str, optional
895
+ Suffix to add to checkpoint filename.
896
+ """
897
+ try:
898
+ # get state dicts, handling DDP wrapping
899
+ noise_predictor_state = (
900
+ self.noise_predictor.module.state_dict() if self.use_ddp
901
+ else self.noise_predictor.state_dict()
902
+ )
903
+ conditional_state = None
904
+ if self.conditional_model is not None:
905
+ conditional_state = (
906
+ self.conditional_model.module.state_dict() if self.use_ddp
907
+ else self.conditional_model.state_dict()
908
+ )
909
+
910
+ checkpoint = {
911
+ 'epoch': epoch,
912
+ 'model_state_dict_noise_predictor': noise_predictor_state,
913
+ 'model_state_dict_conditional': conditional_state,
914
+ 'optimizer_state_dict': self.optimizer.state_dict(),
915
+ 'loss': loss,
916
+ 'variance_scheduler_model': (
917
+ self.forward_diffusion.variance_scheduler.state_dict() if isinstance(self.forward_diffusion.variance_scheduler, nn.Module)
918
+ else self.forward_diffusion.variance_scheduler
919
+ ),
920
+ 'max_epochs': self.max_epochs,
921
+ }
922
+
923
+ filename = f"sde_epoch_{epoch}{suffix}.pth"
924
+ filepath = os.path.join(self.store_path, filename)
925
+ os.makedirs(self.store_path, exist_ok=True)
926
+ torch.save(checkpoint, filepath)
927
+
928
+ print(f"Model saved at epoch {epoch}")
929
+
930
+ except Exception as e:
931
+ print(f"Failed to save model: {e}")
932
+
933
+
934
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Evaluate the noise predictor on the validation loader.

    For every validation batch the noise-prediction loss is computed at
    random timesteps. When ``self.metrics_`` and ``self.reverse_diffusion``
    are both available, samples are additionally generated by running the
    reverse process over all timesteps from pure noise, clamped to
    ``self.image_output_range`` (and rescaled to [0, 1] together with the
    reference batch when ``self.normalize_output`` is set), and scored with
    ``self.metrics_``. Models are switched to eval mode for the duration and
    back to train mode before returning.

    Returns
    -------
    val_loss : float
        Mean validation loss (averaged across processes under DDP).
    fid : float
        Mean FID score, or ``float('inf')`` if not computed.
    mse : float or None
        Mean MSE, or None if not computed.
    psnr : float or None
        Mean PSNR, or None if not computed.
    ssim : float or None
        Mean SSIM, or None if not computed.
    lpips_score : float or None
        Mean LPIPS score, or None if not computed.

    **Notes**

    The sampler type is read from ``reverse_diffusion.sde_method`` (the
    attribute ``SampleSDE`` uses), falling back to ``method`` for objects
    that only expose the older name. No noise tensor is drawn for the
    ``"ode"`` sampler.
    """
    self.noise_predictor.eval()
    if self.conditional_model is not None:
        self.conditional_model.eval()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.forward_diffusion.eval()
        self.reverse_diffusion.eval()

    val_losses = []
    fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

    # Previously only `method` was consulted, which the reverse module used
    # by SampleSDE does not define, so the ODE case was never detected.
    reverse_method = getattr(
        self.reverse_diffusion,
        "sde_method",
        getattr(self.reverse_diffusion, "method", None),
    )
    needs_noise = reverse_method != "ode"

    with torch.no_grad():
        for x, y in self.val_loader:
            x = x.to(self.device)
            x_orig = x.clone()

            # process conditional input
            if self.conditional_model is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            # compute validation loss
            noise = torch.randn_like(x)
            t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

            noisy_x = self.forward_diffusion(x, noise, t)
            predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
            loss = self.objective(predicted_noise, noise)
            val_losses.append(loss.item())

            # generate samples for metrics evaluation
            if self.metrics_ is not None and self.reverse_diffusion is not None:
                xt = torch.randn_like(x)

                # reverse diffusion sampling, last timestep first
                for step_index in reversed(range(self.forward_diffusion.variance_scheduler.num_steps)):
                    time_steps = torch.full((xt.shape[0],), step_index, device=self.device, dtype=torch.long)
                    predicted_noise = self.noise_predictor(xt, time_steps, y_encoded, None)
                    step_noise = torch.randn_like(xt) if needs_noise else None
                    xt = self.reverse_diffusion(xt, step_noise, predicted_noise, time_steps)

                # clamp and normalize generated samples
                low, high = self.image_output_range[0], self.image_output_range[1]
                x_hat = torch.clamp(xt, min=low, max=high)
                if self.normalize_output:
                    x_hat = (x_hat - low) / (high - low)
                    x_orig = (x_orig - low) / (high - low)

                # compute metrics; keep only the ones that are enabled
                fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x_orig, x_hat)

                if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    fid_scores.append(fid)
                if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    mse_scores.append(mse)
                    psnr_scores.append(psnr)
                    ssim_scores.append(ssim)
                if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    lpips_scores.append(lpips_score)

    # compute average metrics
    val_loss = torch.tensor(val_losses).mean().item()

    # all-reduce validation loss across processes for DDP
    if self.use_ddp:
        val_loss_tensor = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
        val_loss = val_loss_tensor.item()

    fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
    mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
    psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
    ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
    lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

    # return to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
1041
+
1042
+
1043
+ ###==================================================================================================================###
1044
+
1045
class SampleSDE(nn.Module):
    """Sampler for generating images using SDE-based generative models.

    Generates images by iteratively denoising random noise using the reverse SDE process
    and a trained noise predictor, as described in Song et al. (2021). Supports both
    unconditional and conditional generation with text prompts.

    Parameters
    ----------
    reverse_diffusion : ReverseSDE
        Reverse SDE diffusion module for denoising.
    noise_predictor : nn.Module
        Model to predict noise added during the forward SDE process.
    image_shape : tuple
        Shape of generated images as (height, width).
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., TextEncoder), default None.
    tokenizer : str or tokenizer instance, optional
        Either a pretrained tokenizer name, loaded with
        ``BertTokenizer.from_pretrained``, or an already constructed
        tokenizer object, which is used as-is. Default "bert-base-uncased".
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : str or torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    image_output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Raises
    ------
    ValueError
        If `image_shape`, `batch_size` or `image_output_range` is invalid.
    """
    def __init__(
            self,
            reverse_diffusion: torch.nn.Module,
            noise_predictor: torch.nn.Module,
            image_shape: Tuple[int, int],
            conditional_model: Optional[torch.nn.Module] = None,
            tokenizer: Union[str, "BertTokenizer"] = "bert-base-uncased",
            max_token_length: int = 77,
            batch_size: int = 1,
            in_channels: int = 3,
            device: Optional[Union[str, torch.device]] = None,
            image_output_range: Tuple[float, float] = (-1.0, 1.0)
    ) -> None:
        super().__init__()

        # validate first so bad arguments fail before any model is moved
        # or any tokenizer is loaded
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # explicit None test: a module's truthiness can depend on __len__
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        # a name is resolved to a pretrained tokenizer; an instance is kept
        if isinstance(tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.max_token_length = max_token_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.image_output_range = image_output_range

    def tokenize(self, prompts: Union[str, List]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized tensors using the configured tokenizer,
        padded and truncated to `max_token_length`.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list
            of strings.

        Returns
        -------
        input_ids : torch.Tensor
            Tokenized input IDs on `self.device`.
        attention_mask : torch.Tensor
            Attention mask on `self.device`.

        Raises
        ------
        TypeError
            If `prompts` is neither a string nor a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(
            self,
            conditions: Optional[Union[str, List]] = None,
            normalize_output: bool = True,
            save_images: bool = True,
            save_path: str = "sde_generated"
    ) -> torch.Tensor:
        """Generates images using the reverse SDE sampling process.

        Starts from Gaussian noise and applies the noise predictor and the reverse
        process once per timestep, from the last timestep down to 0. Text conditions
        are tokenized and encoded a single time before the loop, since they do not
        change between timesteps.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).
        save_images : bool, optional
            If True, saves generated images to `save_path` (default: True).
        save_path : str, optional
            Directory to save generated images (default: "sde_generated").

        Returns
        -------
        generated_imgs : torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width). If
            `normalize_output` is True, images are normalized to [0, 1]; otherwise,
            they are clamped to `image_output_range`.

        Raises
        ------
        ValueError
            If conditions are given without a conditional model, or a conditional
            model is configured but no conditions are given.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            # the prompt embedding is the same for every timestep: build it once
            y = None
            if self.conditional_model is not None and conditions is not None:
                input_ids, attention_masks = self.tokenize(conditions)
                # NOTE(review): the training helper `_process_conditional_input`
                # passes the raw attention mask, while this path passes its
                # inverse; confirm which convention the conditional model expects.
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            # the ODE sampler takes no noise term
            stochastic = self.reverse.sde_method != "ode"

            xt = noisy_samples
            for t in reversed(range(self.reverse.variance_scheduler.num_steps)):
                noise = torch.randn_like(xt) if stochastic else None
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)

                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt = self.reverse(xt, noise, predicted_noise, time_steps)

            generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
            if normalize_output:
                generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])

            # save images if save_images is True
            if save_images:
                os.makedirs(save_path, exist_ok=True)
                for i in range(generated_imgs.size(0)):
                    img_path = os.path.join(save_path, f"image_{i+1}.png")
                    save_image(generated_imgs[i], img_path)

        return generated_imgs

    def to(self, device: torch.device) -> "Self":
        """Moves the module and its components to the specified device.

        Updates the device attribute and moves the reverse diffusion, noise predictor,
        and conditional model (if present) to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for the module and its components. A device string is
            also accepted and stored as a `torch.device`, as in `__init__`.

        Returns
        -------
        sample_sde (SampleSDE) - moved to the specified device.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)