TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/ddim.py
ADDED
|
@@ -0,0 +1,1222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
**Denoising Diffusion Implicit Models (DDIM)**
|
|
3
|
+
|
|
4
|
+
This module provides a complete implementation of DDIM, as described in Song et al.
|
|
5
|
+
(2021, "Denoising Diffusion Implicit Models"). It includes components for forward and
|
|
6
|
+
reverse diffusion processes, hyperparameter management, training, and image sampling.
|
|
7
|
+
Supports both unconditional and conditional generation with text prompts, using a
|
|
8
|
+
subsampled time step schedule for faster sampling compared to DDPM.
|
|
9
|
+
|
|
10
|
+
**Components**
|
|
11
|
+
|
|
12
|
+
- **ForwardDDIM**: Forward diffusion process to add noise.
|
|
13
|
+
- **ReverseDDIM**: Reverse diffusion process to denoise with subsampled steps.
|
|
14
|
+
- **VarianceSchedulerDDIM**: Noise schedule management with subsampled (tau) schedule.
|
|
15
|
+
- **TrainDDIM**: Training loop with mixed precision and scheduling.
|
|
16
|
+
- **SampleDDIM**: Image generation from trained models with subsampled steps.
|
|
17
|
+
|
|
18
|
+
**Notes**
|
|
19
|
+
|
|
20
|
+
- The subsampled time step schedule (tau) enables faster sampling, controlled by the
|
|
21
|
+
`tau_num_steps` parameter of VarianceSchedulerDDIM.
|
|
22
|
+
|
|
23
|
+
**References**:
|
|
24
|
+
|
|
25
|
+
- Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020).
|
|
26
|
+
|
|
27
|
+
-------------------------------------------------------------------------------
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
import torch.nn as nn
|
|
33
|
+
import torch.distributed as dist
|
|
34
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
35
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
36
|
+
from tqdm import tqdm
|
|
37
|
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
|
38
|
+
from transformers import BertTokenizer
|
|
39
|
+
import warnings
|
|
40
|
+
from torchvision.utils import save_image
|
|
41
|
+
from typing import Optional, Tuple, Callable, List, Any, Union, Self
|
|
42
|
+
import os
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
###==================================================================================================================###
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ForwardDDIM(nn.Module):
    """Forward diffusion process of DDIM.

    Implements the forward (noising) process for Denoising Diffusion Implicit Models
    (DDIM), which perturbs clean data with Gaussian noise according to the cumulative
    noise schedule, as defined in Song et al. (2021, "Denoising Diffusion Implicit Models"):

        x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise

    Parameters
    ----------
    `variance_scheduler` : object
        Hyperparameter object (VarianceSchedulerDDIM) containing the noise schedule parameters.
        Expected to have attributes: `num_steps`, `trainable_beta`, and either
        `sqrt_alpha_cumprod` / `sqrt_one_minus_alpha_cumprod` (fixed schedule) or
        `compute_schedule` (trainable schedule).
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Perturbs the input data `x0` by adding Gaussian noise according to the DDIM
        forward process at specified time steps, using cumulative noise schedule parameters.

        Parameters
        ----------
        `x0` : torch.Tensor
            Input data tensor whose first dimension is the batch, e.g.
            (batch_size, channels, height, width). Any rank >= 1 is supported.
        `noise` : torch.Tensor
            Gaussian noise tensor of the same shape as `x0`.
        `time_steps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,),
            where each value is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Noisy data tensor `xt` at the specified time steps, with the same shape as `x0`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` is outside [0, num_steps - 1].
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.variance_scheduler.trainable_beta:
            # Recompute from the current (trainable) betas so gradients reach the schedule.
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.compute_schedule(
                time_steps
            )
        else:
            sqrt_alpha_cumprod_t = self.variance_scheduler.sqrt_alpha_cumprod[time_steps]
            sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.sqrt_one_minus_alpha_cumprod[time_steps]

        # Reshape the per-sample coefficients to (batch, 1, ..., 1) so they broadcast
        # along the batch axis for inputs of any rank (a fixed 4-D view would silently
        # mis-broadcast non-image inputs).
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.to(x0.device).view(coeff_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.to(x0.device).view(coeff_shape)

        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise

        return xt
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
###==================================================================================================================###
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class ReverseDDIM(nn.Module):
    """Reverse diffusion process of DDIM.

    Implements the reverse diffusion process for Denoising Diffusion Implicit Models
    (DDIM), which denoises a noisy input `xt` using a predicted noise component and a
    subsampled time step schedule, as defined in Song et al. (2021).

    Parameters
    ----------
    `variance_scheduler` : object
        Hyperparameter object (VarianceSchedulerDDIM) containing the noise schedule parameters.
        Expected to have attributes: `tau_num_steps`, `eta`, `get_tau_schedule`.
    """

    def __init__(self, variance_scheduler: torch.nn.Module):
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, xt: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor, prev_time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies the reverse diffusion process to the noisy input.

        Denoises the input `xt` at time step `t` to produce the previous step `xt_prev`
        at `prev_time_steps` using the predicted noise and the DDIM update

            x0      = (xt - sqrt(1 - a_t) * eps) / sqrt(a_t)
            sigma   = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
            xt_prev = sqrt(a_prev) * x0 + sqrt(1 - a_prev - sigma^2) * eps + sigma * z

        where `a` denotes alpha_cumprod on the tau schedule and `z ~ N(0, I)`.
        With `eta = 0` the update is deterministic.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input tensor at time step `t`; first dimension is the batch, e.g.
            (batch_size, channels, height, width). Any rank >= 1 is supported.
        `predicted_noise` : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        `time_steps` : torch.Tensor
            Tensor of tau-schedule indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Tensor of previous tau-schedule indices (long), shape (batch_size,), where each
            value is in the range [0, variance_scheduler.tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Denoised tensor at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If `time_steps` or `prev_time_steps` is outside [0, tau_num_steps - 1].
        """
        tau_num_steps = self.variance_scheduler.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.variance_scheduler.get_tau_schedule()

        # (batch, 1, ..., 1) so per-sample coefficients broadcast over inputs of any rank.
        coeff_shape = (-1,) + (1,) * (xt.dim() - 1)
        tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[time_steps].to(xt.device).view(coeff_shape)
        tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[time_steps].to(xt.device).view(coeff_shape)
        prev_tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[prev_time_steps].to(xt.device).view(coeff_shape)
        prev_tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[prev_time_steps].to(xt.device).view(coeff_shape)

        eta = self.variance_scheduler.eta
        x0 = (xt - tau_sqrt_one_minus_alpha_cumprod_t * predicted_noise) / tau_sqrt_alpha_cumprod_t

        # DDIM stochasticity: sigma = eta * sqrt((1 - a_prev)/(1 - a_t)) * sqrt(1 - a_t/a_prev).
        # Clamps guard against division by zero and a negative radicand when t <= prev.
        alpha_cumprod_t = tau_sqrt_alpha_cumprod_t ** 2
        prev_alpha_cumprod_t = prev_tau_sqrt_alpha_cumprod_t ** 2
        noise_coeff = eta * (
            prev_tau_sqrt_one_minus_alpha_cumprod_t / torch.clamp(tau_sqrt_one_minus_alpha_cumprod_t, min=1e-8)
        ) * torch.sqrt(torch.clamp(1 - alpha_cumprod_t / torch.clamp(prev_alpha_cumprod_t, min=1e-8), min=0.0))
        # Remaining variance budget goes to the deterministic "direction pointing to x_t" term.
        direction_coeff = torch.clamp(prev_tau_sqrt_one_minus_alpha_cumprod_t ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = prev_tau_sqrt_alpha_cumprod_t * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
###==================================================================================================================###
|
|
178
|
+
|
|
179
|
+
class VarianceSchedulerDDIM(nn.Module):
    """Hyperparameters for DDIM noise schedule with flexible beta computation.

    Manages the noise schedule parameters for DDIM, including beta values, derived
    quantities (alphas, alpha_cumprod, etc.), and a subsampled time step schedule
    (tau schedule), as inspired by Song et al. (2021). Supports trainable or fixed
    schedules and various beta scheduling methods.

    With a fixed schedule (`trainable_beta=False`) the betas and all derived
    quantities are stored as registered buffers (`betas_buffer`, `alphas`,
    `alpha_cumprod`, `sqrt_alpha_cumprod`, `sqrt_one_minus_alpha_cumprod`).
    With a trainable schedule only the unconstrained parameter `beta_raw` is
    stored; derived quantities are recomputed on demand by `compute_schedule`.

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the DDIM reverse process (default: 0, deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for DDIM sampling (default: 100).
        Must satisfy 0 < tau_num_steps <= num_steps.
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Raises
    ------
    ValueError
        If the beta range, `num_steps`, `tau_num_steps` or `beta_method` is invalid.
    """

    # Clamp bound used when inverting the sigmoid reparameterization. The initial
    # schedule normalizes to exactly 0 and 1 at its endpoints, where logit is -inf/+inf;
    # infinite parameters turn into NaN under e.g. weight-decayed optimizer updates.
    _LOGIT_EPS = 1e-6

    def __init__(
            self,
            eta: Optional[float] = None,
            num_steps: int = 1000,
            tau_num_steps: int = 100,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear"
    ):
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if not (0 < tau_num_steps <= num_steps):
            # 0 yields an empty tau schedule; > num_steps yields duplicated tau indices.
            raise ValueError(
                f"tau_num_steps ({tau_num_steps}) must satisfy 0 < tau_num_steps <= num_steps ({num_steps})")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Reparameterization: beta = beta_start + (beta_end - beta_start) * sigmoid(beta_raw),
            # so betas stay inside (beta_start, beta_end) whatever the optimizer does.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=self._LOGIT_EPS))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # tau_num_steps indices spanning [0, num_steps - 1], spaced by linspace and
        # converted to long.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable."""
        if self.trainable_beta:
            # Map unconstrained parameters into (beta_start, beta_end) via sigmoid.
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the DDIM noise schedule using the
        chosen method, ensuring values are clamped within the specified range.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        beta (torch.Tensor) - Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            spread = beta.max() - beta.min()
            if spread > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / spread
            else:
                # A single step has no spread to normalize (0/0 would give NaN);
                # use the start of the range, as the multi-step schedule does.
                beta = torch.full((num_steps,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes the subsampled (tau) noise schedule for DDIM.

        Returns the noise schedule parameters for the subsampled time steps used in
        DDIM sampling, based on the `tau_indices`.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values for subsampled steps, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values for subsampled steps, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative product of alphas (over the full schedule) at the subsampled steps, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            # Recompute from the current trainable betas.
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule()
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters dynamically from betas.

        Calculates the derived noise schedule parameters (alphas, alpha_cumprod, etc.)
        from the current beta values, as used in the DDIM forward and reverse processes.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, returns parameters only for specified time steps.
            If None, returns parameters for all time steps.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,) or (len(time_steps),).
        alphas : torch.Tensor
            1 - betas, shape (num_steps,) or (len(time_steps),).
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas, shape (num_steps,) or (len(time_steps),).
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod, shape (num_steps,) or (len(time_steps),).
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod), shape (num_steps,) or (len(time_steps),).
        """
        # The property applies the sigmoid constraint when betas are trainable.
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)

        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
###==================================================================================================================###
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class TrainDDIM(nn.Module):
|
|
379
|
+
"""Trainer for Denoising Diffusion Implicit Models (DDIM).
|
|
380
|
+
|
|
381
|
+
Manages the training process for DDIM, optimizing a noise predictor model to learn
|
|
382
|
+
the noise added by the forward diffusion process. Supports conditional training with
|
|
383
|
+
text prompts, mixed precision training, learning rate scheduling, early stopping, and
|
|
384
|
+
checkpointing, as inspired by Song et al. (2021).
|
|
385
|
+
|
|
386
|
+
Parameters
|
|
387
|
+
----------
|
|
388
|
+
`noise_predictor` : nn.Module
|
|
389
|
+
Model to predict noise added during the forward diffusion process.
|
|
390
|
+
`forward_diffusion` : nn.Module
|
|
391
|
+
Forward DDIM diffusion module for adding noise.
|
|
392
|
+
`reverse_diffusion` : nn.Module
|
|
393
|
+
Reverse DDIM diffusion module for denoising.
|
|
394
|
+
`data_loader` : torch.utils.data.DataLoader
|
|
395
|
+
DataLoader for training data.
|
|
396
|
+
`optimizer` : torch.optim.Optimizer
|
|
397
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
398
|
+
`objective` : callable
|
|
399
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
400
|
+
`val_loader` : torch.utils.data.DataLoader, optional
|
|
401
|
+
DataLoader for validation data, default None.
|
|
402
|
+
`max_epochs` : int, optional
|
|
403
|
+
Maximum number of training epochs (default: 1000).
|
|
404
|
+
`device` : torch.device, optional
|
|
405
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
406
|
+
`conditional_model` : nn.Module, optional
|
|
407
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
408
|
+
`metrics_` : object, optional
|
|
409
|
+
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
410
|
+
`bert_tokenizer` : BertTokenizer, optional
|
|
411
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
412
|
+
`max_token_length` : int, optional
|
|
413
|
+
Maximum length for tokenized prompts (default: 77).
|
|
414
|
+
`store_path` : str, optional
|
|
415
|
+
Path to save model checkpoints (default: "ddim_model.pth").
|
|
416
|
+
`patience` : int, optional
|
|
417
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
418
|
+
`warmup_epochs` : int, optional
|
|
419
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
420
|
+
`val_frequency` : int, optional
|
|
421
|
+
Frequency (in epochs) for validation (default: 10).
|
|
422
|
+
`image_output_range` : tuple, optional
|
|
423
|
+
Range for clamping generated images (default: (-1, 1)).
|
|
424
|
+
`normalize_output` : bool, optional
|
|
425
|
+
Whether to normalize generated images to [0, 1] for metrics (default: True).
|
|
426
|
+
`use_ddp` : bool, optional
|
|
427
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
428
|
+
`grad_accumulation_steps` : int, optional
|
|
429
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
430
|
+
`log_frequency` : int, optional
|
|
431
|
+
Frequency (in epochs) at which the training loss is printed (default: 1).
|
|
432
|
+
`use_compilation` : bool, optional
|
|
433
|
+
Whether the model is internally compiled using torch.compile (default: False).
|
|
434
|
+
"""
|
|
435
|
+
def __init__(
        self,
        noise_predictor: torch.nn.Module,
        forward_diffusion: torch.nn.Module,
        reverse_diffusion: torch.nn.Module,
        data_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        conditional_model: Optional[torch.nn.Module] = None,
        metrics_: Optional[Any] = None,
        bert_tokenizer: Optional[BertTokenizer] = None,
        max_token_length: int = 77,
        store_path: Optional[str] = None,
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        image_output_range: Tuple[float, float] = (-1, 1),
        normalize_output: bool = True,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False
) -> None:
    """Set up devices/DDP, move models to the device, and build schedulers and the tokenizer.

    Parameters are described in the class docstring.

    Raises
    ------
    ValueError
        If `bert_tokenizer` is None and the default "bert-base-uncased" tokenizer
        cannot be loaded; also raised by `_setup_ddp` when DDP environment variables
        are missing.
    RuntimeError
        Raised by `_setup_ddp` when DDP is requested but CUDA is unavailable.
    """
    super().__init__()
    # Initialize DDP settings first: _setup_ddp may override self.device.
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # Setup distributed training if enabled.
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # Move models to the appropriate device.
    self.noise_predictor = noise_predictor.to(self.device)
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    # Use an explicit None check: container modules (e.g. nn.Sequential) define
    # __len__, so truthiness would treat an empty container as "no model".
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # Training components.
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ddim_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # Learning rate scheduling.
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # Tokenizer: use the caller's tokenizer when given (previously it was silently
    # ignored, leaving self.tokenizer unset); otherwise load the default.
    if bert_tokenizer is not None:
        self.tokenizer = bert_tokenizer
    else:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
|
|
514
|
+
|
|
515
|
+
def _setup_ddp(self) -> None:
|
|
516
|
+
"""Setup Distributed Data Parallel training configuration.
|
|
517
|
+
|
|
518
|
+
Initializes process group, determines rank information, and sets up
|
|
519
|
+
CUDA device for the current process.
|
|
520
|
+
"""
|
|
521
|
+
# check if DDP environment variables are set
|
|
522
|
+
if "RANK" not in os.environ:
|
|
523
|
+
raise ValueError("DDP enabled but RANK environment variable not set")
|
|
524
|
+
if "LOCAL_RANK" not in os.environ:
|
|
525
|
+
raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
|
|
526
|
+
if "WORLD_SIZE" not in os.environ:
|
|
527
|
+
raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
|
|
528
|
+
|
|
529
|
+
# ensure CUDA is available for DDP
|
|
530
|
+
if not torch.cuda.is_available():
|
|
531
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
532
|
+
|
|
533
|
+
# initialize process group only if not already initialized
|
|
534
|
+
if not torch.distributed.is_initialized():
|
|
535
|
+
init_process_group(backend="nccl")
|
|
536
|
+
|
|
537
|
+
# get rank information
|
|
538
|
+
self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes
|
|
539
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node
|
|
540
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes
|
|
541
|
+
|
|
542
|
+
# set device and make it current
|
|
543
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
544
|
+
torch.cuda.set_device(self.device)
|
|
545
|
+
|
|
546
|
+
# master process handles logging, checkpointing, etc.
|
|
547
|
+
self.master_process = self.ddp_rank == 0
|
|
548
|
+
|
|
549
|
+
if self.master_process:
|
|
550
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
551
|
+
|
|
552
|
+
def _setup_single_gpu(self) -> None:
|
|
553
|
+
"""Setup single GPU or CPU training configuration."""
|
|
554
|
+
self.ddp_rank = 0
|
|
555
|
+
self.ddp_local_rank = 0
|
|
556
|
+
self.ddp_world_size = 1
|
|
557
|
+
self.master_process = True
|
|
558
|
+
|
|
559
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
|
|
560
|
+
"""Loads a training checkpoint to resume training.
|
|
561
|
+
|
|
562
|
+
Restores the state of the noise predictor, conditional model (if applicable),
|
|
563
|
+
and optimizer from a saved checkpoint. Handles DDP model state dict loading.
|
|
564
|
+
|
|
565
|
+
Parameters
|
|
566
|
+
----------
|
|
567
|
+
checkpoint_path : str
|
|
568
|
+
Path to the checkpoint file.
|
|
569
|
+
|
|
570
|
+
Returns
|
|
571
|
+
-------
|
|
572
|
+
epoch : int
|
|
573
|
+
The epoch at which the checkpoint was saved.
|
|
574
|
+
loss : float
|
|
575
|
+
The loss at the checkpoint.
|
|
576
|
+
"""
|
|
577
|
+
try:
|
|
578
|
+
# load checkpoint with proper device mapping
|
|
579
|
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|
580
|
+
except FileNotFoundError:
|
|
581
|
+
raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
|
|
582
|
+
|
|
583
|
+
# load noise predictor state
|
|
584
|
+
if 'model_state_dict_noise_predictor' not in checkpoint:
|
|
585
|
+
raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
|
|
586
|
+
|
|
587
|
+
# handle DDP wrapped model state dict
|
|
588
|
+
state_dict = checkpoint['model_state_dict_noise_predictor']
|
|
589
|
+
if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
|
|
590
|
+
# if loading non-DDP checkpoint into DDP model, add 'module.' prefix
|
|
591
|
+
state_dict = {f'module.{k}': v for k, v in state_dict.items()}
|
|
592
|
+
elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
|
|
593
|
+
# if loading DDP checkpoint into non-DDP model, remove 'module.' prefix
|
|
594
|
+
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
|
595
|
+
|
|
596
|
+
self.noise_predictor.load_state_dict(state_dict)
|
|
597
|
+
|
|
598
|
+
# load conditional model state if applicable
|
|
599
|
+
if self.conditional_model is not None:
|
|
600
|
+
if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
601
|
+
cond_state_dict = checkpoint['model_state_dict_conditional']
|
|
602
|
+
# handle DDP wrapping for conditional model
|
|
603
|
+
if self.use_ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()):
|
|
604
|
+
cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()}
|
|
605
|
+
elif not self.use_ddp and any(key.startswith('module.') for key in cond_state_dict.keys()):
|
|
606
|
+
cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()}
|
|
607
|
+
self.conditional_model.load_state_dict(cond_state_dict)
|
|
608
|
+
else:
|
|
609
|
+
warnings.warn(
|
|
610
|
+
"Checkpoint contains no 'model_state_dict_conditional' or it is None, "
|
|
611
|
+
"skipping conditional model loading"
|
|
612
|
+
)
|
|
613
|
+
|
|
614
|
+
# load variance_scheduler state
|
|
615
|
+
if 'variance_scheduler_model' not in checkpoint:
|
|
616
|
+
raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
|
|
617
|
+
try:
|
|
618
|
+
if isinstance(self.forward_diffusion.variance_scheduler, nn.Module):
|
|
619
|
+
self.forward_diffusion.variance_scheduler.load_state_dict(
|
|
620
|
+
checkpoint['variance_scheduler_model'])
|
|
621
|
+
if isinstance(self.reverse_diffusion.variance_scheduler, nn.Module):
|
|
622
|
+
self.reverse_diffusion.variance_scheduler.load_state_dict(
|
|
623
|
+
checkpoint['variance_scheduler_model'])
|
|
624
|
+
else:
|
|
625
|
+
self.forward_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
|
|
626
|
+
self.reverse_diffusion.variance_scheduler = checkpoint['variance_scheduler_model']
|
|
627
|
+
except Exception as e:
|
|
628
|
+
warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")
|
|
629
|
+
|
|
630
|
+
# load optimizer state
|
|
631
|
+
if 'optimizer_state_dict' not in checkpoint:
|
|
632
|
+
raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
|
|
633
|
+
try:
|
|
634
|
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
635
|
+
except ValueError as e:
|
|
636
|
+
warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
|
|
637
|
+
|
|
638
|
+
epoch = checkpoint.get('epoch', -1)
|
|
639
|
+
loss = checkpoint.get('loss', float('inf'))
|
|
640
|
+
|
|
641
|
+
if self.master_process:
|
|
642
|
+
print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
|
|
643
|
+
return epoch, loss
|
|
644
|
+
|
|
645
|
+
@staticmethod
|
|
646
|
+
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
|
|
647
|
+
"""Creates a learning rate scheduler for warmup.
|
|
648
|
+
|
|
649
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
650
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
651
|
+
|
|
652
|
+
Parameters
|
|
653
|
+
----------
|
|
654
|
+
`optimizer` : torch.optim.Optimizer
|
|
655
|
+
Optimizer to apply the scheduler to.
|
|
656
|
+
`warmup_epochs` : int
|
|
657
|
+
Number of epochs for the warmup phase.
|
|
658
|
+
|
|
659
|
+
Returns
|
|
660
|
+
-------
|
|
661
|
+
lr_scheduler (torch.optim.lr_scheduler.LambdaLR) - Learning rate scheduler for warmup.
|
|
662
|
+
"""
|
|
663
|
+
def lr_lambda(epoch: int) -> float:
|
|
664
|
+
if epoch < warmup_epochs:
|
|
665
|
+
return epoch / warmup_epochs
|
|
666
|
+
return 1.0
|
|
667
|
+
|
|
668
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
669
|
+
|
|
670
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
671
|
+
"""Wrap models with DistributedDataParallel for multi-GPU training."""
|
|
672
|
+
if self.use_ddp:
|
|
673
|
+
# wrap noise predictor with DDP
|
|
674
|
+
self.noise_predictor = DDP(
|
|
675
|
+
self.noise_predictor,
|
|
676
|
+
device_ids=[self.ddp_local_rank],
|
|
677
|
+
find_unused_parameters=True
|
|
678
|
+
)
|
|
679
|
+
|
|
680
|
+
# wrap conditional model with DDP if it exists
|
|
681
|
+
if self.conditional_model is not None:
|
|
682
|
+
self.conditional_model = DDP(
|
|
683
|
+
self.conditional_model,
|
|
684
|
+
device_ids=[self.ddp_local_rank],
|
|
685
|
+
find_unused_parameters=True
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
def forward(self) -> Tuple[List, float]:
|
|
689
|
+
"""Trains the DDIM model to predict noise added by the forward diffusion process.
|
|
690
|
+
|
|
691
|
+
Executes the training loop, optimizing the noise predictor and conditional model
|
|
692
|
+
(if applicable) using mixed precision, gradient clipping, and learning rate
|
|
693
|
+
scheduling. Supports validation, early stopping, and checkpointing.
|
|
694
|
+
|
|
695
|
+
Returns
|
|
696
|
+
-------
|
|
697
|
+
train_losses : list of float
|
|
698
|
+
List of mean training losses per epoch.
|
|
699
|
+
best_val_loss : float
|
|
700
|
+
Best validation or training loss achieved.
|
|
701
|
+
"""
|
|
702
|
+
|
|
703
|
+
# set models to training mode
|
|
704
|
+
self.noise_predictor.train()
|
|
705
|
+
if self.conditional_model is not None:
|
|
706
|
+
self.conditional_model.train()
|
|
707
|
+
if self.forward_diffusion.variance_scheduler.trainable_beta:
|
|
708
|
+
self.reverse_diffusion.train()
|
|
709
|
+
self.forward_diffusion.train()
|
|
710
|
+
else:
|
|
711
|
+
self.reverse_diffusion.eval()
|
|
712
|
+
self.forward_diffusion.eval()
|
|
713
|
+
|
|
714
|
+
# compile models for optimization (if supported)
|
|
715
|
+
if self.use_compilation:
|
|
716
|
+
try:
|
|
717
|
+
self.noise_predictor = torch.compile(self.noise_predictor)
|
|
718
|
+
if self.conditional_model is not None:
|
|
719
|
+
self.conditional_model = torch.compile(self.conditional_model)
|
|
720
|
+
except Exception as e:
|
|
721
|
+
if self.master_process:
|
|
722
|
+
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
723
|
+
|
|
724
|
+
# wrap models for DDP after compilation
|
|
725
|
+
self._wrap_models_for_ddp()
|
|
726
|
+
|
|
727
|
+
# initialize training components
|
|
728
|
+
scaler = torch.GradScaler()
|
|
729
|
+
train_losses = []
|
|
730
|
+
best_val_loss = float("inf")
|
|
731
|
+
wait = 0
|
|
732
|
+
|
|
733
|
+
# main training loop
|
|
734
|
+
for epoch in range(self.max_epochs):
|
|
735
|
+
# set epoch for distributed sampler if using DDP
|
|
736
|
+
if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
|
|
737
|
+
self.data_loader.sampler.set_epoch(epoch)
|
|
738
|
+
|
|
739
|
+
train_losses_epoch = []
|
|
740
|
+
|
|
741
|
+
# training step loop with gradient accumulation
|
|
742
|
+
for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
|
|
743
|
+
x = x.to(self.device)
|
|
744
|
+
|
|
745
|
+
# process conditional inputs if conditional model exists
|
|
746
|
+
if self.conditional_model is not None:
|
|
747
|
+
y_encoded = self._process_conditional_input(y)
|
|
748
|
+
else:
|
|
749
|
+
y_encoded = None
|
|
750
|
+
|
|
751
|
+
# forward pass with mixed precision
|
|
752
|
+
with torch.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
|
|
753
|
+
# generate noise and timesteps
|
|
754
|
+
noise = torch.randn_like(x).to(self.device)
|
|
755
|
+
t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)
|
|
756
|
+
|
|
757
|
+
# apply forward diffusion
|
|
758
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
759
|
+
|
|
760
|
+
# predict noise
|
|
761
|
+
predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
|
|
762
|
+
|
|
763
|
+
# compute loss and scale for gradient accumulation
|
|
764
|
+
loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps
|
|
765
|
+
|
|
766
|
+
# backward pass
|
|
767
|
+
scaler.scale(loss).backward()
|
|
768
|
+
|
|
769
|
+
# gradient accumulation and optimizer step
|
|
770
|
+
if (step + 1) % self.grad_accumulation_steps == 0:
|
|
771
|
+
# clip gradients
|
|
772
|
+
scaler.unscale_(self.optimizer)
|
|
773
|
+
torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
|
|
774
|
+
if self.conditional_model is not None:
|
|
775
|
+
torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
|
|
776
|
+
|
|
777
|
+
# optimizer step
|
|
778
|
+
scaler.step(self.optimizer)
|
|
779
|
+
scaler.update()
|
|
780
|
+
self.optimizer.zero_grad()
|
|
781
|
+
|
|
782
|
+
# update learning rate (warmup scheduler)
|
|
783
|
+
self.warmup_lr_scheduler.step()
|
|
784
|
+
|
|
785
|
+
# record loss (unscaled)
|
|
786
|
+
train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)
|
|
787
|
+
|
|
788
|
+
# compute mean training loss
|
|
789
|
+
mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
|
|
790
|
+
|
|
791
|
+
# all-reduce loss across processes for DDP
|
|
792
|
+
if self.use_ddp:
|
|
793
|
+
loss_tensor = torch.tensor(mean_train_loss, device=self.device)
|
|
794
|
+
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
|
795
|
+
mean_train_loss = loss_tensor.item()
|
|
796
|
+
|
|
797
|
+
train_losses.append(mean_train_loss)
|
|
798
|
+
|
|
799
|
+
# print training progress (only master process)
|
|
800
|
+
if self.master_process and (epoch + 1) % self.log_frequency == 0:
|
|
801
|
+
current_lr = self.optimizer.param_groups[0]['lr']
|
|
802
|
+
print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")
|
|
803
|
+
|
|
804
|
+
# validation step
|
|
805
|
+
if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
|
|
806
|
+
val_metrics = self.validate()
|
|
807
|
+
val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics
|
|
808
|
+
|
|
809
|
+
if self.master_process:
|
|
810
|
+
print(f" | Val Loss: {val_loss:.4f}", end="")
|
|
811
|
+
if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
|
|
812
|
+
print(f" | FID: {fid:.4f}", end="")
|
|
813
|
+
if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
|
|
814
|
+
print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
|
|
815
|
+
if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
|
|
816
|
+
print(f" | LPIPS: {lpips_score:.4f}", end="")
|
|
817
|
+
print()
|
|
818
|
+
|
|
819
|
+
current_best = val_loss
|
|
820
|
+
self.scheduler.step(val_loss)
|
|
821
|
+
else:
|
|
822
|
+
if self.master_process:
|
|
823
|
+
print()
|
|
824
|
+
current_best = mean_train_loss
|
|
825
|
+
self.scheduler.step(mean_train_loss)
|
|
826
|
+
|
|
827
|
+
# save checkpoint and early stopping (only master process)
|
|
828
|
+
if self.master_process:
|
|
829
|
+
if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
|
|
830
|
+
best_val_loss = current_best
|
|
831
|
+
wait = 0
|
|
832
|
+
self._save_checkpoint(epoch + 1, best_val_loss)
|
|
833
|
+
else:
|
|
834
|
+
wait += 1
|
|
835
|
+
if wait >= self.patience:
|
|
836
|
+
print("Early stopping triggered")
|
|
837
|
+
self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
|
|
838
|
+
break
|
|
839
|
+
|
|
840
|
+
# clean up DDP
|
|
841
|
+
if self.use_ddp:
|
|
842
|
+
destroy_process_group()
|
|
843
|
+
|
|
844
|
+
return train_losses, best_val_loss
|
|
845
|
+
|
|
846
|
+
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
|
|
847
|
+
"""Process conditional input for text-to-image generation.
|
|
848
|
+
|
|
849
|
+
Parameters
|
|
850
|
+
----------
|
|
851
|
+
y : torch.Tensor or list
|
|
852
|
+
Conditional input (text prompts).
|
|
853
|
+
|
|
854
|
+
Returns
|
|
855
|
+
-------
|
|
856
|
+
torch.Tensor
|
|
857
|
+
Encoded conditional input.
|
|
858
|
+
"""
|
|
859
|
+
# convert to string list
|
|
860
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
861
|
+
y_list = [str(item) for item in y_list]
|
|
862
|
+
|
|
863
|
+
# tokenize
|
|
864
|
+
y_encoded = self.tokenizer(
|
|
865
|
+
y_list,
|
|
866
|
+
padding="max_length",
|
|
867
|
+
truncation=True,
|
|
868
|
+
max_length=self.max_token_length,
|
|
869
|
+
return_tensors="pt"
|
|
870
|
+
).to(self.device)
|
|
871
|
+
|
|
872
|
+
# get embeddings
|
|
873
|
+
input_ids = y_encoded["input_ids"]
|
|
874
|
+
attention_mask = y_encoded["attention_mask"]
|
|
875
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
876
|
+
|
|
877
|
+
return y_encoded
|
|
878
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
|
|
879
|
+
"""Save model checkpoint (only called by master process).
|
|
880
|
+
|
|
881
|
+
Parameters
|
|
882
|
+
----------
|
|
883
|
+
epoch : int
|
|
884
|
+
Current epoch number.
|
|
885
|
+
loss : float
|
|
886
|
+
Current loss value.
|
|
887
|
+
suffix : str, optional
|
|
888
|
+
Suffix to add to checkpoint filename.
|
|
889
|
+
"""
|
|
890
|
+
try:
|
|
891
|
+
# get state dicts, handling DDP wrapping
|
|
892
|
+
noise_predictor_state = (
|
|
893
|
+
self.noise_predictor.module.state_dict() if self.use_ddp
|
|
894
|
+
else self.noise_predictor.state_dict()
|
|
895
|
+
)
|
|
896
|
+
conditional_state = None
|
|
897
|
+
if self.conditional_model is not None:
|
|
898
|
+
conditional_state = (
|
|
899
|
+
self.conditional_model.module.state_dict() if self.use_ddp
|
|
900
|
+
else self.conditional_model.state_dict()
|
|
901
|
+
)
|
|
902
|
+
|
|
903
|
+
checkpoint = {
|
|
904
|
+
'epoch': epoch,
|
|
905
|
+
'model_state_dict_noise_predictor': noise_predictor_state,
|
|
906
|
+
'model_state_dict_conditional': conditional_state,
|
|
907
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
908
|
+
'loss': loss,
|
|
909
|
+
'variance_scheduler_model': (
|
|
910
|
+
self.forward_diffusion.variance_scheduler.state_dict() if isinstance(
|
|
911
|
+
self.forward_diffusion.variance_scheduler, nn.Module)
|
|
912
|
+
else self.forward_diffusion.variance_scheduler
|
|
913
|
+
),
|
|
914
|
+
'max_epochs': self.max_epochs,
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
filename = f"ddim_epoch_{epoch}{suffix}.pth"
|
|
918
|
+
filepath = os.path.join(self.store_path, filename)
|
|
919
|
+
os.makedirs(self.store_path, exist_ok=True)
|
|
920
|
+
torch.save(checkpoint, filepath)
|
|
921
|
+
|
|
922
|
+
print(f"Model saved at epoch {epoch}")
|
|
923
|
+
|
|
924
|
+
except Exception as e:
|
|
925
|
+
print(f"Failed to save model: {e}")
|
|
926
|
+
|
|
927
|
+
def validate(self) -> Tuple[float, float, float, float, float, float]:
|
|
928
|
+
"""Validates the noise predictor and computes evaluation Metrics.
|
|
929
|
+
|
|
930
|
+
Computes validation loss (MSE between predicted and ground truth noise) and generates
|
|
931
|
+
samples using the reverse diffusion model by manually iterating over timesteps.
|
|
932
|
+
Decodes samples to images and computes image-domain Metrics (MSE, PSNR, SSIM, FID, LPIPS)
|
|
933
|
+
if metrics_ is provided.
|
|
934
|
+
|
|
935
|
+
Returns
|
|
936
|
+
-------
|
|
937
|
+
val_loss : float
|
|
938
|
+
Mean validation loss.
|
|
939
|
+
fid : float, or `float('inf')` if not computed
|
|
940
|
+
Mean FID score.
|
|
941
|
+
mse : float, or None if not computed
|
|
942
|
+
Mean MSE
|
|
943
|
+
psnr : float, or None if not computed
|
|
944
|
+
Mean PSNR
|
|
945
|
+
ssim : float, or None if not computed
|
|
946
|
+
Mean SSIM
|
|
947
|
+
lpips_score : float, or None if not computed
|
|
948
|
+
Mean LPIPS score
|
|
949
|
+
"""
|
|
950
|
+
|
|
951
|
+
self.noise_predictor.eval()
|
|
952
|
+
if self.conditional_model is not None:
|
|
953
|
+
self.conditional_model.eval()
|
|
954
|
+
if self.forward_diffusion.variance_scheduler.trainable_beta:
|
|
955
|
+
self.forward_diffusion.eval()
|
|
956
|
+
self.reverse_diffusion.eval()
|
|
957
|
+
|
|
958
|
+
val_losses = []
|
|
959
|
+
fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
|
|
960
|
+
|
|
961
|
+
with torch.no_grad():
|
|
962
|
+
for x, y in self.val_loader:
|
|
963
|
+
x = x.to(self.device)
|
|
964
|
+
x_orig = x.clone()
|
|
965
|
+
|
|
966
|
+
# process conditional input
|
|
967
|
+
if self.conditional_model is not None:
|
|
968
|
+
y_encoded = self._process_conditional_input(y)
|
|
969
|
+
else:
|
|
970
|
+
y_encoded = None
|
|
971
|
+
|
|
972
|
+
# compute validation loss
|
|
973
|
+
noise = torch.randn_like(x).to(self.device)
|
|
974
|
+
t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)
|
|
975
|
+
|
|
976
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
977
|
+
predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
|
|
978
|
+
loss = self.objective(predicted_noise, noise)
|
|
979
|
+
val_losses.append(loss.item())
|
|
980
|
+
|
|
981
|
+
# generate samples for metrics evaluation
|
|
982
|
+
if self.metrics_ is not None and self.reverse_diffusion is not None:
|
|
983
|
+
xt = torch.randn_like(x).to(self.device)
|
|
984
|
+
|
|
985
|
+
# reverse diffusion sampling
|
|
986
|
+
for t in reversed(range(self.forward_diffusion.variance_scheduler.tau_num_steps)):
|
|
987
|
+
time_steps = torch.full((xt.shape[0],), t, device=self.device)#, dtype=torch.long)
|
|
988
|
+
prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device)#, dtype=torch.long)
|
|
989
|
+
predicted_noise = self.noise_predictor(xt, time_steps, y_encoded, None)
|
|
990
|
+
xt, _ = self.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)
|
|
991
|
+
|
|
992
|
+
# clamp and normalize generated samples
|
|
993
|
+
x_hat = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
|
|
994
|
+
if self.normalize_output:
|
|
995
|
+
x_hat = (x_hat - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
|
|
996
|
+
x_orig = (x_orig - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
|
|
997
|
+
|
|
998
|
+
# compute metrics
|
|
999
|
+
metrics_result = self.metrics_.forward(x_orig, x_hat)
|
|
1000
|
+
fid, mse, psnr, ssim, lpips_score = metrics_result
|
|
1001
|
+
|
|
1002
|
+
if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
|
|
1003
|
+
fid_scores.append(fid)
|
|
1004
|
+
if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
|
|
1005
|
+
mse_scores.append(mse)
|
|
1006
|
+
psnr_scores.append(psnr)
|
|
1007
|
+
ssim_scores.append(ssim)
|
|
1008
|
+
if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
|
|
1009
|
+
lpips_scores.append(lpips_score)
|
|
1010
|
+
|
|
1011
|
+
# compute average metrics
|
|
1012
|
+
val_loss = torch.tensor(val_losses).mean().item()
|
|
1013
|
+
|
|
1014
|
+
# all-reduce validation metrics across processes for DDP
|
|
1015
|
+
if self.use_ddp:
|
|
1016
|
+
val_loss_tensor = torch.tensor(val_loss, device=self.device)
|
|
1017
|
+
dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
|
|
1018
|
+
val_loss = val_loss_tensor.item()
|
|
1019
|
+
|
|
1020
|
+
fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
|
|
1021
|
+
mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
|
|
1022
|
+
psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
|
|
1023
|
+
ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
|
|
1024
|
+
lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
|
|
1025
|
+
|
|
1026
|
+
# return to training mode
|
|
1027
|
+
self.noise_predictor.train()
|
|
1028
|
+
if self.conditional_model is not None:
|
|
1029
|
+
self.conditional_model.train()
|
|
1030
|
+
if self.forward_diffusion.variance_scheduler.trainable_beta:
|
|
1031
|
+
self.reverse_diffusion.train()
|
|
1032
|
+
self.forward_diffusion.train()
|
|
1033
|
+
|
|
1034
|
+
return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
1035
|
+
|
|
1036
|
+
###==================================================================================================================###
|
|
1037
|
+
|
|
1038
|
+
class SampleDDIM(nn.Module):
|
|
1039
|
+
"""Image generation using a trained DDIM model.
|
|
1040
|
+
|
|
1041
|
+
Implements the sampling process for DDIM, generating images by iteratively denoising
|
|
1042
|
+
random noise using a trained noise predictor and reverse diffusion process with a
|
|
1043
|
+
subsampled time step schedule. Supports conditional generation with text prompts,
|
|
1044
|
+
as inspired by Song et al. (2021).
|
|
1045
|
+
|
|
1046
|
+
Parameters
|
|
1047
|
+
----------
|
|
1048
|
+
`reverse_diffusion` : nn.Module
|
|
1049
|
+
Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
|
|
1050
|
+
`noise_predictor` : nn.Module
|
|
1051
|
+
Trained model to predict noise at each time step.
|
|
1052
|
+
`image_shape` : tuple
|
|
1053
|
+
Tuple of (height, width) specifying the generated image dimensions.
|
|
1054
|
+
`conditional_model` : nn.Module, optional
|
|
1055
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
1056
|
+
`tokenizer` : str, optional
|
|
1057
|
+
Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
|
|
1058
|
+
`max_length` : int, optional
|
|
1059
|
+
Maximum length for tokenized prompts (default: 77).
|
|
1060
|
+
`batch_size` : int, optional
|
|
1061
|
+
Number of images to generate per batch (default: 1).
|
|
1062
|
+
`in_channels` : int, optional
|
|
1063
|
+
Number of input channels for generated images (default: 3).
|
|
1064
|
+
`device` : torch.device, optional
|
|
1065
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
1066
|
+
`output_range` : tuple, optional
|
|
1067
|
+
Tuple of (min, max) for clamping generated images (default: (-1, 1)).
|
|
1068
|
+
"""
|
|
1069
|
+
def __init__(
|
|
1070
|
+
self,
|
|
1071
|
+
reverse_diffusion: torch.nn.Module,
|
|
1072
|
+
noise_predictor: torch.nn.Module,
|
|
1073
|
+
image_shape: Tuple[int, int],
|
|
1074
|
+
conditional_model: Optional[torch.nn.Module] = None,
|
|
1075
|
+
bert_tokenizer: str = "bert-base-uncased",
|
|
1076
|
+
max_token_length: int = 77,
|
|
1077
|
+
batch_size: int = 1,
|
|
1078
|
+
in_channels: int = 3,
|
|
1079
|
+
device: Optional[str] = None,
|
|
1080
|
+
image_output_range: Tuple[float, float] = (-1.0, 1.0)
|
|
1081
|
+
) -> None:
|
|
1082
|
+
super().__init__()
|
|
1083
|
+
if device is None:
|
|
1084
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
1085
|
+
elif isinstance(device, str):
|
|
1086
|
+
self.device = torch.device(device)
|
|
1087
|
+
else:
|
|
1088
|
+
self.device = device
|
|
1089
|
+
self.reverse = reverse_diffusion.to(self.device)
|
|
1090
|
+
self.noise_predictor = noise_predictor.to(self.device)
|
|
1091
|
+
self.conditional_model = conditional_model.to(self.device) if conditional_model else None
|
|
1092
|
+
self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer)
|
|
1093
|
+
self.max_token_length = max_token_length
|
|
1094
|
+
self.in_channels = in_channels
|
|
1095
|
+
self.image_shape = image_shape
|
|
1096
|
+
self.batch_size = batch_size
|
|
1097
|
+
self.image_output_range = image_output_range
|
|
1098
|
+
|
|
1099
|
+
if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
|
|
1100
|
+
isinstance(s, int) and s > 0 for s in image_shape):
|
|
1101
|
+
raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
|
|
1102
|
+
if batch_size <= 0:
|
|
1103
|
+
raise ValueError("batch_size must be positive")
|
|
1104
|
+
if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
|
|
1105
|
+
raise ValueError("image_output_range must be a tuple (min, max) with min < max")
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
1109
|
+
"""Tokenizes text prompts for conditional generation.
|
|
1110
|
+
|
|
1111
|
+
Converts input prompts into tokenized input IDs and attention masks using the
|
|
1112
|
+
specified tokenizer, suitable for use with the conditional model.
|
|
1113
|
+
|
|
1114
|
+
Parameters
|
|
1115
|
+
----------
|
|
1116
|
+
`prompts` : str or list
|
|
1117
|
+
A single text prompt or a list of text prompts.
|
|
1118
|
+
|
|
1119
|
+
Returns
|
|
1120
|
+
-------
|
|
1121
|
+
input_ids : torch.Tensor
|
|
1122
|
+
Tokenized input IDs, shape (batch_size, max_length).
|
|
1123
|
+
attention_mask : torch.Tensor
|
|
1124
|
+
Attention mask, shape (batch_size, max_length).
|
|
1125
|
+
"""
|
|
1126
|
+
if isinstance(prompts, str):
|
|
1127
|
+
prompts = [prompts]
|
|
1128
|
+
elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
|
|
1129
|
+
raise TypeError("prompts must be a string or list of strings")
|
|
1130
|
+
encoded = self.tokenizer(
|
|
1131
|
+
prompts,
|
|
1132
|
+
padding="max_length",
|
|
1133
|
+
truncation=True,
|
|
1134
|
+
max_length=self.max_token_length,
|
|
1135
|
+
return_tensors="pt"
|
|
1136
|
+
)
|
|
1137
|
+
return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
|
|
1138
|
+
|
|
1139
|
+
def forward(self, conditions: Optional[Union[str, List]] = None, normalize_output: bool = True, save_images: bool = True, save_path: str = "ddim_generated") -> torch.Tensor:
|
|
1140
|
+
"""Generates images using the DDIM sampling process.
|
|
1141
|
+
|
|
1142
|
+
Iteratively denoises random noise to generate images using the reverse diffusion
|
|
1143
|
+
process with a subsampled time step schedule and noise predictor. Supports
|
|
1144
|
+
conditional generation with text prompts.
|
|
1145
|
+
|
|
1146
|
+
Parameters
|
|
1147
|
+
----------
|
|
1148
|
+
`conditions` : str or list, optional
|
|
1149
|
+
Text prompt(s) for conditional generation, default None.
|
|
1150
|
+
`normalize_output` : bool, optional
|
|
1151
|
+
If True, normalizes output images to [0, 1] (default: True).
|
|
1152
|
+
`save_images` : bool, optional
|
|
1153
|
+
If True, saves generated images to `save_path` (default: True).
|
|
1154
|
+
`save_path` : str, optional
|
|
1155
|
+
Directory to save generated images (default: "ddim_generated").
|
|
1156
|
+
|
|
1157
|
+
Returns
|
|
1158
|
+
-------
|
|
1159
|
+
generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
|
|
1160
|
+
"""
|
|
1161
|
+
|
|
1162
|
+
if conditions is not None and self.conditional_model is None:
|
|
1163
|
+
raise ValueError("Conditions provided but no conditional model specified")
|
|
1164
|
+
if conditions is None and self.conditional_model is not None:
|
|
1165
|
+
raise ValueError("Conditions must be provided for conditional model")
|
|
1166
|
+
|
|
1167
|
+
noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)
|
|
1168
|
+
|
|
1169
|
+
self.noise_predictor.eval()
|
|
1170
|
+
self.reverse.eval()
|
|
1171
|
+
if self.conditional_model:
|
|
1172
|
+
self.conditional_model.eval()
|
|
1173
|
+
|
|
1174
|
+
with torch.no_grad():
|
|
1175
|
+
xt = noisy_samples
|
|
1176
|
+
for t in reversed(range(self.reverse.variance_scheduler.tau_num_steps)):
|
|
1177
|
+
time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)
|
|
1178
|
+
prev_time_steps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long)
|
|
1179
|
+
|
|
1180
|
+
if self.conditional_model is not None and conditions is not None:
|
|
1181
|
+
input_ids, attention_masks = self.tokenize(conditions)
|
|
1182
|
+
key_padding_mask = (attention_masks == 0)
|
|
1183
|
+
y = self.conditional_model(input_ids, key_padding_mask)
|
|
1184
|
+
predicted_noise = self.noise_predictor(xt, time_steps, y, None)
|
|
1185
|
+
else:
|
|
1186
|
+
predicted_noise = self.noise_predictor(xt, time_steps, None)
|
|
1187
|
+
|
|
1188
|
+
xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
|
|
1189
|
+
|
|
1190
|
+
generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
|
|
1191
|
+
if normalize_output:
|
|
1192
|
+
generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
|
|
1193
|
+
|
|
1194
|
+
if save_images:
|
|
1195
|
+
os.makedirs(save_path, exist_ok=True) # create directory if it doesn't exist
|
|
1196
|
+
for i in range(generated_imgs.size(0)):
|
|
1197
|
+
img_path = os.path.join(save_path, f"image_{i+1}.png")
|
|
1198
|
+
save_image(generated_imgs[i], img_path)
|
|
1199
|
+
|
|
1200
|
+
return generated_imgs
|
|
1201
|
+
|
|
1202
|
+
def to(self, device: torch.device) -> Self:
|
|
1203
|
+
"""Moves the module and its components to the specified device.
|
|
1204
|
+
|
|
1205
|
+
Updates the device attribute and moves the reverse diffusion, noise predictor,
|
|
1206
|
+
and conditional model (if present) to the specified device.
|
|
1207
|
+
|
|
1208
|
+
Parameters
|
|
1209
|
+
----------
|
|
1210
|
+
`device` : torch.device
|
|
1211
|
+
Target device for the module and its components.
|
|
1212
|
+
|
|
1213
|
+
Returns
|
|
1214
|
+
-------
|
|
1215
|
+
sample_ddim (SampleDDIM) - moved to the specified device.
|
|
1216
|
+
"""
|
|
1217
|
+
self.device = device
|
|
1218
|
+
self.noise_predictor.to(device)
|
|
1219
|
+
self.reverse.to(device)
|
|
1220
|
+
if self.conditional_model:
|
|
1221
|
+
self.conditional_model.to(device)
|
|
1222
|
+
return super().to(device)
|