TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/ddpm.py
ADDED
|
@@ -0,0 +1,1153 @@
|
|
|
1
|
+
"""
|
|
2
|
+
**Denoising Diffusion Probabilistic Models (DDPM) implementation**
|
|
3
|
+
|
|
4
|
+
This module provides a complete implementation of DDPM, as described in Ho et al.
|
|
5
|
+
(2020, "Denoising Diffusion Probabilistic Models"). It includes components for forward
|
|
6
|
+
and reverse diffusion processes, hyperparameter management, training, and image
|
|
7
|
+
sampling. Supports both unconditional and conditional generation with text prompts.
|
|
8
|
+
|
|
9
|
+
**Components**
|
|
10
|
+
|
|
11
|
+
- **ForwardDDPM**: Forward diffusion process to add noise.
|
|
12
|
+
- **ReverseDDPM**: Reverse diffusion process to denoise.
|
|
13
|
+
- **VarianceSchedulerDDPM**: Noise schedule management.
|
|
14
|
+
- **TrainDDPM**: Training loop with mixed precision and scheduling.
|
|
15
|
+
- **SampleDDPM**: Image generation from trained models.
|
|
16
|
+
|
|
17
|
+
**References**
|
|
18
|
+
|
|
19
|
+
- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models.
|
|
20
|
+
|
|
21
|
+
- Salimans, Tim, et al. "PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications."
|
|
22
|
+
arXiv preprint arXiv:1701.05517 (2017).
|
|
23
|
+
|
|
24
|
+
-------------------------------------------------------------------------------
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
import torch
|
|
30
|
+
import torch.nn as nn
|
|
31
|
+
from typing import Optional, Tuple, Callable, List, Any, Union, Self
|
|
32
|
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
|
33
|
+
import torch.distributed as dist
|
|
34
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
35
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
36
|
+
from tqdm import tqdm
|
|
37
|
+
from transformers import BertTokenizer
|
|
38
|
+
import warnings
|
|
39
|
+
from torchvision.utils import save_image
|
|
40
|
+
import os
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
###==================================================================================================================###
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ForwardDDPM(nn.Module):
    """Forward diffusion process for Denoising Diffusion Probabilistic Models (DDPM).

    Implements the forward (noising) process of Ho et al. (2020), which perturbs
    clean data with Gaussian noise according to a variance schedule. The schedule
    can be either fixed or trainable, depending on the provided variance scheduler.

    Parameters
    ----------
    variance_scheduler : torch.nn.Module
        Noise-schedule object (e.g. ``VarianceSchedulerDDPM``). It must expose
        ``num_steps`` and ``trainable_beta``. When ``trainable_beta`` is False it
        must also expose ``sqrt_alpha_bars`` and ``sqrt_one_minus_alpha_bars``;
        when it is True it must provide ``compute_schedule(time_steps)``.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Apply the forward diffusion process to ``x0``.

        Uses the closed-form reparameterization
        ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

        Parameters
        ----------
        x0 : torch.Tensor
            Clean input whose first dimension is the batch, e.g.
            ``(batch_size, channels, height, width)`` for images. Any number of
            trailing dimensions is supported.
        noise : torch.Tensor
            Gaussian noise with the same shape as ``x0``.
        time_steps : torch.Tensor
            Integer time-step indices of shape ``(batch_size,)``, each in
            ``[0, variance_scheduler.num_steps - 1]``.

        Returns
        -------
        torch.Tensor
            Noisy tensor ``x_t`` with the same shape as ``x0``.

        Raises
        ------
        ValueError
            If any time step lies outside ``[0, variance_scheduler.num_steps - 1]``.
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.variance_scheduler.trainable_beta:
            # trainable schedules are recomputed from the learnable parameters on every call
            _, _, _, sqrt_alpha_bar_t, sqrt_one_minus_alpha_bar_t = self.variance_scheduler.compute_schedule(time_steps)
        else:
            sqrt_alpha_bar_t = self.variance_scheduler.sqrt_alpha_bars[time_steps]
            sqrt_one_minus_alpha_bar_t = self.variance_scheduler.sqrt_one_minus_alpha_bars[time_steps]

        # reshape the per-sample (B,) coefficients to (B, 1, ..., 1) so they broadcast
        # over inputs of any rank (identical to view(-1, 1, 1, 1) for 4D images)
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = sqrt_alpha_bar_t.to(x0.device).view(broadcast_shape)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.to(x0.device).view(broadcast_shape)

        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
|
|
102
|
+
|
|
103
|
+
###==================================================================================================================###
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ReverseDDPM(nn.Module):
    """Reverse diffusion process for Denoising Diffusion Probabilistic Models (DDPM).

    Implements one step of the DDPM reverse (denoising) process of Ho et al.
    (2020). It computes the mean of p(x_{t-1} | x_t) from a predicted noise
    component and adds Gaussian noise for samples whose time step is above 0.
    The noise schedule can be either fixed or trainable.

    Parameters
    ----------
    variance_scheduler : torch.nn.Module
        Noise-schedule object (e.g. ``VarianceSchedulerDDPM``). It must expose
        ``num_steps`` and ``trainable_beta``. When ``trainable_beta`` is False it
        must also expose ``betas``, ``alphas`` and ``alpha_bars``; when it is
        True it must provide ``compute_schedule(time_steps)``.
    """

    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def _schedule_at(self, time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(betas, alphas, alpha_bars)`` gathered at ``time_steps``."""
        if self.variance_scheduler.trainable_beta:
            betas_t, alphas_t, alpha_bars_t, _, _ = self.variance_scheduler.compute_schedule(time_steps)
            return betas_t, alphas_t, alpha_bars_t
        return (
            self.variance_scheduler.betas[time_steps],
            self.variance_scheduler.alphas[time_steps],
            self.variance_scheduler.alpha_bars[time_steps],
        )

    def forward(self, xt: torch.Tensor, predicted_noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Apply one reverse diffusion step to ``xt``.

        Computes ``mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)``
        and, for samples with ``t > 0``, adds ``sigma_t * z`` with
        ``sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t`` and
        ``z ~ N(0, I)``. Samples with ``t == 0`` receive ``mu`` only.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input at time step ``t`` whose first dimension is the batch,
            e.g. ``(batch_size, channels, height, width)``. Any number of
            trailing dimensions is supported.
        predicted_noise : torch.Tensor
            Predicted noise with the same shape as ``xt``, typically output by a
            neural network.
        time_steps : torch.Tensor
            Integer time-step indices of shape ``(batch_size,)``, each in
            ``[0, variance_scheduler.num_steps - 1]``.

        Returns
        -------
        torch.Tensor
            Tensor ``x_{t-1}`` with the same shape as ``xt``.

        Raises
        ------
        ValueError
            If any time step lies outside ``[0, variance_scheduler.num_steps - 1]``.
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        device = xt.device
        non_zero_mask = time_steps > 0

        betas_t, alphas_t, alpha_bars_t = (v.to(device) for v in self._schedule_at(time_steps))

        # alpha_bar_{t-1}; entries with t == 0 stay 0 (their noise term is masked out below)
        alpha_bars_t_minus_1 = torch.zeros_like(alpha_bars_t)
        if non_zero_mask.any():
            _, _, prev_alpha_bars = self._schedule_at(time_steps[non_zero_mask] - 1)
            alpha_bars_t_minus_1[non_zero_mask.to(device)] = prev_alpha_bars.to(device)

        # reshape per-sample (B,) coefficients to (B, 1, ..., 1) so they broadcast over any rank
        broadcast_shape = (-1,) + (1,) * (xt.dim() - 1)
        sqrt_alphas_t = torch.sqrt(alphas_t).view(broadcast_shape)
        sqrt_one_minus_alpha_bars_t = torch.sqrt(1 - alpha_bars_t).view(broadcast_shape)

        mu = (xt - (betas_t.view(broadcast_shape) / sqrt_one_minus_alpha_bars_t) * predicted_noise) / sqrt_alphas_t

        # time steps are validated non-negative, so "no t > 0" means every sample is at t == 0
        if not non_zero_mask.any():
            return mu

        variance = (1 - alpha_bars_t_minus_1) / (1 - alpha_bars_t) * betas_t
        std = torch.sqrt(variance).view(broadcast_shape)

        z = torch.randn_like(xt)
        # mask lives on xt's device so CPU time_steps work with accelerator inputs
        noise_scale = non_zero_mask.to(device=device, dtype=xt.dtype).view(broadcast_shape)
        return mu + noise_scale * std * z
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
###==================================================================================================================###
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class VarianceSchedulerDDPM(nn.Module):
    """Noise schedule (variance scheduler) for Denoising Diffusion Probabilistic Models.

    Computes the beta schedule and its derived quantities (alphas, cumulative
    alpha products and their square roots) used by the forward and reverse
    processes, following Ho et al. (2020). The schedule can be fixed or trainable.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000). Must be positive.
    beta_start : float, optional
        Smallest beta value (default: 1e-4).
    beta_end : float, optional
        Largest beta value (default: 0.02). The pair must satisfy
        ``0 < beta_start < beta_end < 1``.
    trainable_beta : bool, optional
        Whether the beta schedule is learnable (default: False).
    beta_method : str, optional
        Method for computing the (initial) beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Notes
    -----
    With a fixed schedule, ``betas``, ``alphas``, ``alpha_bars``, ``sqrt_alpha_bars``
    and ``sqrt_one_minus_alpha_bars`` are registered as buffers of shape
    ``(num_steps,)``. With a trainable schedule only ``betas`` exists. It is an
    ``nn.Parameter`` of unconstrained logits, and the actual betas are
    ``sigmoid(betas) * (beta_end - beta_start) + beta_start``. Read them through
    ``compute_schedule``.
    """

    def __init__(self, num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02, trainable_beta: bool = False, beta_method: str = "linear") -> None:
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Store the inverse of the sigmoid parameterization applied in
            # compute_schedule(), so the initial trainable schedule reproduces
            # betas_init (storing log(betas) would not round-trip through it).
            # eps keeps the logits finite where betas equal beta_start / beta_end.
            unit_betas = (betas_init - beta_start) / (beta_end - beta_start)
            self.betas = nn.Parameter(torch.logit(unit_betas, eps=1e-6))
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_bars', torch.sqrt(self.alpha_bars))
            self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1 - self.alpha_bars))

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Compute the beta schedule for the requested method.

        Parameters
        ----------
        beta_range : tuple of float
            ``(min_beta, max_beta)``; the result is clamped to this range.
        num_steps : int
            Number of diffusion steps.
        method : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Beta values of shape ``(num_steps,)``.

        Raises
        ------
        ValueError
            If ``method`` is not supported.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min**0.5, beta_max**0.5, num_steps)
            beta = x**2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # a single step has no range to normalize (would divide by zero -> NaN);
                # start at beta_min like the "linear" and "quadratic" methods do
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the noise-schedule quantities from the current betas.

        Parameters
        ----------
        time_steps : torch.Tensor, optional
            Integer time-step indices of shape ``(batch_size,)``. If None, values
            for all steps are returned.

        Returns
        -------
        betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars : torch.Tensor
            Each of shape ``(batch_size,)`` if ``time_steps`` is given, else ``(num_steps,)``.
        """
        if self.trainable_beta:
            # map the unconstrained trainable logits into (beta_start, beta_end)
            betas = torch.sigmoid(self.betas) * (self.beta_end - self.beta_start) + self.beta_start
        else:
            betas = self.betas

        alphas = 1 - betas
        # cumprod runs over every step, so gather only after it is computed
        alpha_bars = torch.cumprod(alphas, dim=0)

        if time_steps is not None:
            betas = betas[time_steps]
            alphas = alphas[time_steps]
            alpha_bars = alpha_bars[time_steps]

        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
###==================================================================================================================###
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
class TrainDDPM(nn.Module):
|
|
310
|
+
"""Trainer for Denoising Diffusion Probabilistic Models (DDPM) with Multi-GPU Support.
|
|
311
|
+
|
|
312
|
+
Manages the training process for DDPM, optimizing a noise predictor model to learn
|
|
313
|
+
the noise added by the forward diffusion process. Supports conditional training with
|
|
314
|
+
text prompts, mixed precision training, learning rate scheduling, early stopping,
|
|
315
|
+
checkpointing, and distributed data parallel (DDP) training across multiple GPUs.
|
|
316
|
+
|
|
317
|
+
Parameters
|
|
318
|
+
----------
|
|
319
|
+
noise_predictor : nn.Module
|
|
320
|
+
Model to predict noise added during the forward diffusion process.
|
|
321
|
+
forward_diffusion : nn.Module
|
|
322
|
+
Forward DDPM diffusion module for adding noise.
|
|
323
|
+
reverse_diffusion: nn.Module
|
|
324
|
+
Reverse DDPM diffusion module for denoising.
|
|
325
|
+
data_loader : torch.utils.data.DataLoader
|
|
326
|
+
DataLoader for training data. Should be wrapped with DistributedSampler for DDP.
|
|
327
|
+
optimizer : torch.optim.Optimizer
|
|
328
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
329
|
+
objective : callable
|
|
330
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
331
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
332
|
+
DataLoader for validation data, default None.
|
|
333
|
+
max_epochs : int, optional
|
|
334
|
+
Maximum number of training epochs (default: 1000).
|
|
335
|
+
device : torch.device, optional
|
|
336
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
337
|
+
conditional_model : nn.Module, optional
|
|
338
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
339
|
+
metrics_ : object, optional
|
|
340
|
+
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
341
|
+
bert_tokenizer : BertTokenizer, optional
|
|
342
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
343
|
+
max_token_length : int, optional
|
|
344
|
+
Maximum length for tokenized prompts (default: 77).
|
|
345
|
+
store_path : str, optional
|
|
346
|
+
Path to save model checkpoints (default: "ddpm_model").
|
|
347
|
+
patience : int, optional
|
|
348
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
349
|
+
warmup_epochs : int, optional
|
|
350
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
351
|
+
val_frequency : int, optional
|
|
352
|
+
Frequency (in epochs) for validation (default: 10).
|
|
353
|
+
image_output_range : tuple, optional
|
|
354
|
+
Range for clamping generated images (default: (-1, 1)).
|
|
355
|
+
normalize_output : bool, optional
|
|
356
|
+
Whether to normalize generated images to [0, 1] for metrics (default: True).
|
|
357
|
+
use_ddp : bool, optional
|
|
358
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
359
|
+
grad_accumulation_steps : int, optional
|
|
360
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
361
|
+
log_frequency : int, optional
|
|
362
|
+
Frequency (in epochs) at which the training loss is printed (default: 1).
|
|
363
|
+
use_compilation : bool, optional
|
|
364
|
+
Whether the model is internally compiled using torch.compile (default: False).
|
|
365
|
+
"""
|
|
366
|
+
|
|
367
|
+
def __init__(
    self,
    noise_predictor: torch.nn.Module,
    forward_diffusion: torch.nn.Module,
    reverse_diffusion: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    conditional_model: Optional[torch.nn.Module] = None,
    metrics_: Optional[Any] = None,
    bert_tokenizer: Optional[BertTokenizer] = None,
    max_token_length: int = 77,
    store_path: Optional[str] = None,
    patience: int = 100,
    warmup_epochs: int = 100,
    val_frequency: int = 10,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    normalize_output: bool = True,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False
) -> None:
    """Set up devices, models, learning-rate schedules and the tokenizer.

    The class docstring describes every argument.

    Raises
    ------
    ValueError
        If ``use_ddp`` is set but RANK / LOCAL_RANK / WORLD_SIZE are missing, or
        if no ``bert_tokenizer`` is given and the default one cannot be loaded.
    RuntimeError
        If ``use_ddp`` is set but CUDA is not available.
    """
    super().__init__()

    # DDP settings come first: _setup_ddp() replaces self.device with the local-rank GPU
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to the (possibly DDP-assigned) device
    self.noise_predictor = noise_predictor.to(self.device)
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    # explicit None check: container modules such as an empty nn.Sequential define
    # __len__ and would be dropped by a plain truthiness test
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "ddpm_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning-rate scheduling; the plateau scheduler reuses the early-stopping patience
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # the default tokenizer is loaded whenever none is supplied, even without a conditional model
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        self.tokenizer = bert_tokenizer
|
|
449
|
+
|
|
450
|
+
def _setup_ddp(self) -> None:
    """Configure this process for Distributed Data Parallel training.

    Checks that RANK, LOCAL_RANK and WORLD_SIZE are present in the environment
    (as set by launchers such as torchrun) and that CUDA is available. It then
    initializes the NCCL process group if none exists yet, records the rank
    information and binds the process to ``cuda:<LOCAL_RANK>``. Rank 0 becomes
    the master process, which handles logging and checkpointing.

    Raises
    ------
    ValueError
        If one of the required environment variables is missing.
    RuntimeError
        If CUDA is not available.
    """
    for env_name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
        if env_name not in os.environ:
            raise ValueError(f"DDP enabled but {env_name} environment variable not set")

    if not torch.cuda.is_available():
        raise RuntimeError("DDP requires CUDA but CUDA is not available")

    # a process group may already exist if the caller set one up
    if not torch.distributed.is_initialized():
        init_process_group(backend="nccl")

    env = os.environ
    self.ddp_rank = int(env["RANK"])              # global rank across all nodes
    self.ddp_local_rank = int(env["LOCAL_RANK"])  # rank on this node, selects the GPU
    self.ddp_world_size = int(env["WORLD_SIZE"])  # total number of processes

    self.device = torch.device(f"cuda:{self.ddp_local_rank}")
    torch.cuda.set_device(self.device)

    self.master_process = self.ddp_rank == 0
    if self.master_process:
        print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
486
|
+
|
|
487
|
+
def _setup_single_gpu(self) -> None:
    """Configure single-device (one GPU or CPU) training.

    Records rank bookkeeping as if this were the only process: global and
    local rank 0 in a world of size 1, which also makes it the master process.
    """
    self.ddp_rank, self.ddp_local_rank, self.ddp_world_size = 0, 0, 1
    self.master_process = True
|
|
493
|
+
|
|
494
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Load a training checkpoint to resume training.

    Restores the noise predictor, the conditional model (if any), the variance
    scheduler of the forward and reverse diffusion modules, and the optimizer.
    The ``module.`` key prefix added by DistributedDataParallel is added or
    removed depending on whether each model is currently DDP-wrapped.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        Epoch stored in the checkpoint (-1 if absent).
    loss : float
        Loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If the checkpoint file does not exist.
    KeyError
        If the noise-predictor, variance-scheduler or optimizer state is missing.
    """
    def _match_ddp_prefix(state_dict: dict, model: torch.nn.Module) -> dict:
        # Align the 'module.' prefix with whether `model` is actually a DDP wrapper.
        # Checking the wrapper (not self.use_ddp) matters because models are only
        # wrapped by _wrap_models_for_ddp, so a checkpoint can be loaded before wrapping.
        wrapped = isinstance(model, DDP)
        has_prefix = any(key.startswith('module.') for key in state_dict)
        if wrapped and not has_prefix:
            return {f'module.{k}': v for k, v in state_dict.items()}
        if not wrapped and has_prefix:
            # strip only the leading prefix; inner submodules may legitimately be named 'module'
            return {k.removeprefix('module.'): v for k, v in state_dict.items()}
        return state_dict

    try:
        # map tensors straight onto this process's device
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from e

    # noise predictor
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    self.noise_predictor.load_state_dict(
        _match_ddp_prefix(checkpoint['model_state_dict_noise_predictor'], self.noise_predictor)
    )

    # conditional model, if this trainer has one
    if self.conditional_model is not None:
        cond_state_dict = checkpoint.get('model_state_dict_conditional')
        if cond_state_dict is not None:
            self.conditional_model.load_state_dict(_match_ddp_prefix(cond_state_dict, self.conditional_model))
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # variance scheduler: handle forward and reverse diffusion independently
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    scheduler_state = checkpoint['variance_scheduler_model']
    try:
        for diffusion in (self.forward_diffusion, self.reverse_diffusion):
            if isinstance(diffusion.variance_scheduler, nn.Module):
                diffusion.variance_scheduler.load_state_dict(scheduler_state)
            else:
                diffusion.variance_scheduler = scheduler_state
    except Exception as e:
        warnings.warn(
            f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # optimizer (a mismatched state is tolerated with a warning)
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
    return epoch, loss
|
|
580
|
+
|
|
581
|
+
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a ``LambdaLR`` scheduler implementing a linear learning-rate warmup.

    Each parameter group's initial learning rate is multiplied by
    ``step / warmup_epochs`` while ``step < warmup_epochs`` (so the very first value
    is 0) and by 1.0 from then on. ``step`` is the scheduler's own counter, i.e. how
    many times ``scheduler.step()`` has been called. A non-positive ``warmup_epochs``
    means no warmup: the multiplier is always 1.0.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is scheduled.
    warmup_epochs : int
        Length of the warmup phase, counted in scheduler steps.

    Returns
    -------
    torch.optim.lr_scheduler.LambdaLR
        Scheduler applying the warmup multiplier.
    """

    def _warmup_factor(step):
        # linear ramp 0 -> 1 during warmup, constant 1.0 afterwards
        return step / warmup_epochs if step < warmup_epochs else 1.0

    return LambdaLR(optimizer, _warmup_factor)
|
|
607
|
+
|
|
608
|
+
def _wrap_models_for_ddp(self) -> None:
    """Wrap the trainable models in DistributedDataParallel when DDP is enabled.

    Does nothing unless ``self.use_ddp`` is truthy. Otherwise ``self.noise_predictor``
    and, if present, ``self.conditional_model`` are replaced by DDP wrappers bound to
    ``self.ddp_local_rank`` with ``find_unused_parameters=True``.
    """
    if not self.use_ddp:
        return

    ddp_options = dict(device_ids=[self.ddp_local_rank], find_unused_parameters=True)

    # the noise predictor is always trained
    self.noise_predictor = DDP(self.noise_predictor, **ddp_options)

    # the conditional (text) model is optional
    if self.conditional_model is not None:
        self.conditional_model = DDP(self.conditional_model, **ddp_options)
|
|
625
|
+
|
|
626
|
+
def forward(self) -> Tuple[List, float]:
    """Trains the DDPM model to predict noise added by the forward diffusion process.

    Executes the training loop with support for distributed training, gradient accumulation,
    mixed precision, gradient clipping, and learning rate scheduling. Includes validation,
    early stopping, and checkpointing functionality.

    Early-stopping bookkeeping (``best_val_loss``, ``wait``, the ``break``) runs on every
    rank. The monitored losses are all-reduced when DDP is enabled, so all ranks see the
    same values and leave the epoch loop together. Only the master process prints and
    saves checkpoints.

    NOTE(review): the optimizer only steps when ``(step + 1)`` is a multiple of
    ``grad_accumulation_steps``. Gradients left over at the end of an epoch are not
    flushed and carry into the next epoch.

    Returns
    -------
    train_losses : list of float
        List of mean training losses per epoch.
    best_val_loss : float
        Best validation or training loss achieved (on epochs where checkpoints are considered).
    """
    # set models to training mode
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (if supported)
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # torch.autocast expects a device *type*. Comparing `self.device == 'cuda'` is False for
    # 'cuda:0' or torch.device('cuda', 0), which selected CPU autocast for CUDA tensors;
    # derive the type from the device instead (non-CUDA devices keep the 'cpu' fallback).
    autocast_device = 'cuda' if torch.device(self.device).type == 'cuda' else 'cpu'

    # initialize training components
    scaler = torch.GradScaler()
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []

        # training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # process conditional inputs if conditional model exists
            y_encoded = self._process_conditional_input(y) if self.conditional_model is not None else None

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                # generate noise (already on x's device) and timesteps
                noise = torch.randn_like(x)
                t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

                # apply forward diffusion and predict the added noise
                noisy_x = self.forward_diffusion(x, noise, t)
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)

                # divide so gradients accumulated over several steps average out
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # backward pass (outside autocast, as recommended for AMP)
            scaler.scale(loss).backward()

            # gradient accumulation and optimizer step
            if (step + 1) % self.grad_accumulation_steps == 0:
                # unscale before clipping so the norm is computed on true gradients
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
                if self.conditional_model is not None:
                    torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # warmup scheduler advances once per optimizer update
                self.warmup_lr_scheduler.step()

            # record loss (undo the accumulation scaling)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # compute mean training loss
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()

        # all-reduce loss across processes for DDP
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        train_losses.append(mean_train_loss)

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        # validation step (all ranks participate: validate() all-reduces its loss)
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # Early stopping is decided on every rank so all processes break together (a
        # master-only break leaves the other ranks blocked in the next collective);
        # only the master process writes checkpoints and prints.
        if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
            best_val_loss = current_best
            wait = 0
            if self.master_process:
                self._save_checkpoint(epoch + 1, best_val_loss)
        else:
            wait += 1
            if wait >= self.patience:
                if self.master_process:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
783
|
+
|
|
784
|
+
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
    """Tokenize a batch of conditions and encode them with the conditional model.

    Every item of ``y`` is turned into a string (tensor labels are first moved to the CPU
    and converted to Python values). The strings are tokenized with ``self.tokenizer``,
    padded/truncated to ``self.max_token_length`` and moved to ``self.device``.

    NOTE(review): the conditional model receives the tokenizer's ``attention_mask`` here,
    whereas ``SampleDDPM.forward`` passes ``attention_mask == 0`` as its second argument.
    Confirm which convention the conditional model expects.

    Parameters
    ----------
    y : torch.Tensor or list
        Conditional input (text prompts or labels) for one batch.

    Returns
    -------
    torch.Tensor
        Output of ``self.conditional_model(input_ids, attention_mask)``.
    """
    raw_items = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
    prompts = list(map(str, raw_items))

    batch_encoding = self.tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt"
    ).to(self.device)

    token_ids = batch_encoding["input_ids"]
    token_mask = batch_encoding["attention_mask"]
    return self.conditional_model(token_ids, token_mask)
|
|
816
|
+
|
|
817
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
    """Save model checkpoint (intended to be called by the master process only).

    State dicts are taken from the underlying, unwrapped modules:

    - the DDP wrapper (``.module``) is removed when ``self.use_ddp`` is set;
    - a ``torch.compile`` wrapper, which keeps the original module as ``._orig_mod``,
      is removed too.

    As a result the saved keys carry neither a ``module.`` nor an ``_orig_mod.`` prefix
    and load into plain, uncompiled models. Any failure is reported with ``print`` and
    not raised, so a failed save does not abort training.

    Parameters
    ----------
    epoch : int
        Current epoch number (also used in the filename ``ddpm_epoch_{epoch}{suffix}.pth``).
    loss : float
        Current loss value.
    suffix : str, optional
        Suffix to add to checkpoint filename.
    """

    def _unwrapped(model: nn.Module) -> nn.Module:
        # DDP wraps the (possibly compiled) model, so strip it first, then the
        # torch.compile wrapper if there is one
        if self.use_ddp:
            model = model.module
        return getattr(model, "_orig_mod", model)

    try:
        noise_predictor_state = _unwrapped(self.noise_predictor).state_dict()
        conditional_state = (
            _unwrapped(self.conditional_model).state_dict()
            if self.conditional_model is not None else None
        )

        variance_scheduler = self.forward_diffusion.variance_scheduler
        checkpoint = {
            'epoch': epoch,
            'model_state_dict_noise_predictor': noise_predictor_state,
            'model_state_dict_conditional': conditional_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            # nn.Module schedulers are stored as state dicts, anything else as-is
            'variance_scheduler_model': (
                variance_scheduler.state_dict() if isinstance(variance_scheduler, nn.Module)
                else variance_scheduler
            ),
            'max_epochs': self.max_epochs,
        }

        filename = f"ddpm_epoch_{epoch}{suffix}.pth"
        filepath = os.path.join(self.store_path, filename)
        os.makedirs(self.store_path, exist_ok=True)
        torch.save(checkpoint, filepath)

        print(f"Model saved at epoch {epoch}")

    except Exception as e:
        print(f"Failed to save model: {e}")
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
|
|
867
|
+
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Evaluate the noise predictor on ``self.val_loader``.

    For each validation batch, noise and random time steps are sampled, the batch is
    diffused with ``self.forward_diffusion``, and ``self.objective`` is computed between
    the predicted and the sampled noise.

    When both ``self.metrics_`` and ``self.reverse_diffusion`` are set, the method also:

    - runs the full reverse chain starting from pure noise;
    - clamps the result to ``self.image_output_range``;
    - rescales both the sample and the batch to [0, 1] if ``self.normalize_output`` is set;
    - scores the sample against the batch with ``self.metrics_.forward``.

    Models are put in eval mode for the duration of the call and back in train mode
    afterwards.

    Returns
    -------
    tuple
        ``(val_loss, fid, mse, psnr, ssim, lpips_score)``. ``val_loss`` is averaged over
        batches (and over processes when DDP is enabled). ``fid`` is ``inf`` and the other
        metrics are ``None`` when they were not computed.
    """
    # modules whose train/eval mode is toggled around validation
    toggled = [self.noise_predictor]
    if self.conditional_model is not None:
        toggled.append(self.conditional_model)
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        toggled.extend([self.forward_diffusion, self.reverse_diffusion])
    for module in toggled:
        module.eval()

    batch_losses = []
    collected = {"fid": [], "mse": [], "psnr": [], "ssim": [], "lpips": []}

    with torch.no_grad():
        for x, y in self.val_loader:
            x = x.to(self.device)
            reference = x.clone()
            cond = self._process_conditional_input(y) if self.conditional_model is not None else None

            # validation loss on a randomly noised copy of the batch
            num_steps = self.forward_diffusion.variance_scheduler.num_steps
            noise = torch.randn_like(x).to(self.device)
            t = torch.randint(0, num_steps, (x.shape[0],)).to(self.device)
            predicted = self.noise_predictor(self.forward_diffusion(x, noise, t), t, cond, None)
            batch_losses.append(self.objective(predicted, noise).item())

            if self.metrics_ is None or self.reverse_diffusion is None:
                continue

            # full reverse-diffusion sampling from pure noise
            sample = torch.randn_like(x).to(self.device)
            for step_idx in reversed(range(num_steps)):
                step_t = torch.full((sample.shape[0],), step_idx, device=self.device, dtype=torch.long)
                step_noise = self.noise_predictor(sample, step_t, cond, None)
                sample = self.reverse_diffusion(sample, step_noise, step_t)

            low, high = self.image_output_range[0], self.image_output_range[1]
            generated = torch.clamp(sample, min=low, max=high)
            if self.normalize_output:
                span = high - low
                generated = (generated - low) / span
                reference = (reference - low) / span

            fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(reference, generated)

            if getattr(self.metrics_, 'fid', False):
                collected["fid"].append(fid)
            if getattr(self.metrics_, 'metrics', False):
                collected["mse"].append(mse)
                collected["psnr"].append(psnr)
                collected["ssim"].append(ssim)
            if getattr(self.metrics_, 'lpips', False):
                collected["lpips"].append(lpips_score)

    val_loss = torch.tensor(batch_losses).mean().item()

    # average the validation loss across processes for DDP
    if self.use_ddp:
        reduced = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(reduced, op=dist.ReduceOp.AVG)
        val_loss = reduced.item()

    def _mean_or(values, fallback):
        return torch.tensor(values).mean().item() if values else fallback

    fid_avg = _mean_or(collected["fid"], float('inf'))
    mse_avg = _mean_or(collected["mse"], None)
    psnr_avg = _mean_or(collected["psnr"], None)
    ssim_avg = _mean_or(collected["ssim"], None)
    lpips_avg = _mean_or(collected["lpips"], None)

    # return to training mode
    for module in toggled:
        module.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
###==================================================================================================================###
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
class SampleDDPM(nn.Module):
|
|
967
|
+
"""mage generation using a trained Denoising Diffusion Probabilistic Model (DDPM).
|
|
968
|
+
|
|
969
|
+
Implements the sampling process for DDPM, generating images by iteratively
|
|
970
|
+
denoising random noise using a trained noise predictor and reverse diffusion
|
|
971
|
+
process. Supports conditional generation with text prompts via a conditional
|
|
972
|
+
model, as inspired by Ho et al. (2020).
|
|
973
|
+
|
|
974
|
+
Parameters
|
|
975
|
+
----------
|
|
976
|
+
reverse_diffusion : nn.Module
|
|
977
|
+
Reverse diffusion module (e.g., ReverseDDPM) for the reverse process.
|
|
978
|
+
noise_predictor : nn.Module
|
|
979
|
+
Trained model to predict noise at each time step.
|
|
980
|
+
image_shape : tuple
|
|
981
|
+
Tuple of (height, width) specifying the generated image dimensions.
|
|
982
|
+
conditional_model : nn.Module, optional
|
|
983
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
984
|
+
tokenizer : str, optional
|
|
985
|
+
Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
|
|
986
|
+
max_token_length : int, optional
|
|
987
|
+
Maximum length for tokenized prompts (default: 77).
|
|
988
|
+
batch_size : int, optional
|
|
989
|
+
Number of images to generate per batch (default: 1).
|
|
990
|
+
in_channels : int, optional
|
|
991
|
+
Number of input channels for generated images (default: 3).
|
|
992
|
+
device : torch.device, optional
|
|
993
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
994
|
+
image_output_range : tuple, optional
|
|
995
|
+
Tuple of (min, max) for clamping generated images (default: (-1, 1)).
|
|
996
|
+
"""
|
|
997
|
+
def __init__(
|
|
998
|
+
self,
|
|
999
|
+
reverse_diffusion: torch.nn.Module,
|
|
1000
|
+
noise_predictor: torch.nn.Module,
|
|
1001
|
+
image_shape: Tuple[int, int],
|
|
1002
|
+
conditional_model: Optional[torch.nn.Module] = None,
|
|
1003
|
+
tokenizer: str = "bert-base-uncased",
|
|
1004
|
+
max_token_length: int = 77,
|
|
1005
|
+
batch_size: int = 1,
|
|
1006
|
+
in_channels: int = 3,
|
|
1007
|
+
device: Optional[str] = None,
|
|
1008
|
+
image_output_range: Tuple[float, float] = (-1.0, 1.0)
|
|
1009
|
+
) -> None:
|
|
1010
|
+
super().__init__()
|
|
1011
|
+
if device is None:
|
|
1012
|
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
1013
|
+
elif isinstance(device, str):
|
|
1014
|
+
self.device = torch.device(device)
|
|
1015
|
+
else:
|
|
1016
|
+
self.device = device
|
|
1017
|
+
self.reverse = reverse_diffusion.to(self.device)
|
|
1018
|
+
self.noise_predictor = noise_predictor.to(self.device)
|
|
1019
|
+
self.conditional_model = conditional_model.to(self.device) if conditional_model else None
|
|
1020
|
+
self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
|
|
1021
|
+
self.max_token_length = max_token_length
|
|
1022
|
+
self.in_channels = in_channels
|
|
1023
|
+
self.image_shape = image_shape
|
|
1024
|
+
self.batch_size = batch_size
|
|
1025
|
+
self.image_output_range = image_output_range
|
|
1026
|
+
|
|
1027
|
+
if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(
|
|
1028
|
+
isinstance(s, int) and s > 0 for s in image_shape):
|
|
1029
|
+
raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
|
|
1030
|
+
if batch_size <= 0:
|
|
1031
|
+
raise ValueError("batch_size must be positive")
|
|
1032
|
+
if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
|
|
1033
|
+
raise ValueError("output_range must be a tuple (min, max) with min < max")
|
|
1034
|
+
|
|
1035
|
+
def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
1036
|
+
"""Tokenizes text prompts for conditional generation.
|
|
1037
|
+
|
|
1038
|
+
Converts input prompts into tokenized input IDs and attention masks using the
|
|
1039
|
+
specified tokenizer, suitable for use with the conditional model.
|
|
1040
|
+
|
|
1041
|
+
Parameters
|
|
1042
|
+
----------
|
|
1043
|
+
prompts : str or list
|
|
1044
|
+
A single text prompt or a list of text prompts.
|
|
1045
|
+
|
|
1046
|
+
Returns
|
|
1047
|
+
-------
|
|
1048
|
+
input_ids : torch.Tensor
|
|
1049
|
+
Tokenized input IDs, shape (batch_size, max_length).
|
|
1050
|
+
attention_mask : torch.Tensor
|
|
1051
|
+
Attention mask, shape (batch_size, max_length).
|
|
1052
|
+
"""
|
|
1053
|
+
if isinstance(prompts, str):
|
|
1054
|
+
prompts = [prompts]
|
|
1055
|
+
elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
|
|
1056
|
+
raise TypeError("prompts must be a string or list of strings")
|
|
1057
|
+
encoded = self.tokenizer(
|
|
1058
|
+
prompts,
|
|
1059
|
+
padding="max_length",
|
|
1060
|
+
truncation=True,
|
|
1061
|
+
max_length=self.max_token_length,
|
|
1062
|
+
return_tensors="pt"
|
|
1063
|
+
)
|
|
1064
|
+
return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
|
|
1065
|
+
|
|
1066
|
+
def forward(
|
|
1067
|
+
self,
|
|
1068
|
+
conditions: Optional[Union[str, List]] = None,
|
|
1069
|
+
normalize_output: bool = True,
|
|
1070
|
+
save_images: bool = True,
|
|
1071
|
+
save_path: str = "ddpm_generated"
|
|
1072
|
+
) -> torch.Tensor:
|
|
1073
|
+
"""Generates images using the DDPM sampling process.
|
|
1074
|
+
|
|
1075
|
+
Iteratively denoises random noise to generate images using the reverse diffusion
|
|
1076
|
+
process and noise predictor. Supports conditional generation with text prompts.
|
|
1077
|
+
Optionally saves generated images to a specified directory.
|
|
1078
|
+
|
|
1079
|
+
Parameters
|
|
1080
|
+
----------
|
|
1081
|
+
conditions : str or list, optional
|
|
1082
|
+
Text prompt(s) for conditional generation, default None.
|
|
1083
|
+
normalize_output : bool, optional
|
|
1084
|
+
If True, normalizes output images to [0, 1] (default: True).
|
|
1085
|
+
save_images : bool, optional
|
|
1086
|
+
If True, saves generated images to `save_path` (default: True).
|
|
1087
|
+
save_path : str, optional
|
|
1088
|
+
Directory to save generated images (default: "ddpm_generated").
|
|
1089
|
+
|
|
1090
|
+
Returns
|
|
1091
|
+
-------
|
|
1092
|
+
generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `output_range`.
|
|
1093
|
+
"""
|
|
1094
|
+
if conditions is not None and self.conditional_model is None:
|
|
1095
|
+
raise ValueError("Conditions provided but no conditional model specified")
|
|
1096
|
+
if conditions is None and self.conditional_model is not None:
|
|
1097
|
+
raise ValueError("Conditions must be provided for conditional model")
|
|
1098
|
+
|
|
1099
|
+
noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(
|
|
1100
|
+
self.device)
|
|
1101
|
+
|
|
1102
|
+
self.noise_predictor.eval()
|
|
1103
|
+
self.reverse.eval()
|
|
1104
|
+
if self.conditional_model:
|
|
1105
|
+
self.conditional_model.eval()
|
|
1106
|
+
|
|
1107
|
+
with torch.no_grad():
|
|
1108
|
+
xt = noisy_samples
|
|
1109
|
+
for t in reversed(range(self.reverse.variance_scheduler.num_steps)):
|
|
1110
|
+
time_steps = torch.full((self.batch_size,), t, device=self.device)#, dtype=torch.long)
|
|
1111
|
+
if self.conditional_model is not None and conditions is not None:
|
|
1112
|
+
input_ids, attention_masks = self.tokenize(conditions)
|
|
1113
|
+
key_padding_mask = (attention_masks == 0)
|
|
1114
|
+
y = self.conditional_model(input_ids, key_padding_mask)
|
|
1115
|
+
predicted_noise = self.noise_predictor(xt, time_steps, y, None)
|
|
1116
|
+
else:
|
|
1117
|
+
predicted_noise = self.noise_predictor(xt, time_steps, None, None)
|
|
1118
|
+
xt = self.reverse(xt, predicted_noise, time_steps)
|
|
1119
|
+
|
|
1120
|
+
generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
|
|
1121
|
+
if normalize_output:
|
|
1122
|
+
generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
|
|
1123
|
+
|
|
1124
|
+
# save images if save_images is True
|
|
1125
|
+
if save_images:
|
|
1126
|
+
os.makedirs(save_path, exist_ok=True) # create directory if it doesn't exist
|
|
1127
|
+
for i in range(generated_imgs.size(0)):
|
|
1128
|
+
img_path = os.path.join(save_path, f"image_{i+1}.png")
|
|
1129
|
+
save_image(generated_imgs[i], img_path)
|
|
1130
|
+
|
|
1131
|
+
return generated_imgs
|
|
1132
|
+
|
|
1133
|
+
def to(self, device: torch.device) -> Self:
|
|
1134
|
+
"""Moves the module and its components to the specified device.
|
|
1135
|
+
|
|
1136
|
+
Updates the device attribute and moves the reverse diffusion, noise predictor,
|
|
1137
|
+
and conditional model (if present) to the specified device.
|
|
1138
|
+
|
|
1139
|
+
Parameters
|
|
1140
|
+
----------
|
|
1141
|
+
device : torch.device
|
|
1142
|
+
Target device for the module and its components.
|
|
1143
|
+
|
|
1144
|
+
Returns
|
|
1145
|
+
-------
|
|
1146
|
+
sample_ddpm (SampleDDPM) - moved to the specified device.
|
|
1147
|
+
"""
|
|
1148
|
+
self.device = device
|
|
1149
|
+
self.noise_predictor.to(device)
|
|
1150
|
+
self.reverse.to(device)
|
|
1151
|
+
if self.conditional_model:
|
|
1152
|
+
self.conditional_model.to(device)
|
|
1153
|
+
return super().to(device)
|