TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/sde.py
ADDED
|
@@ -0,0 +1,1231 @@
|
|
|
1
|
+
"""
|
|
2
|
+
**Score-Based Generative Modeling with Stochastic Differential Equations (SDE)**
|
|
3
|
+
|
|
4
|
+
This module implements a complete framework for score-based generative models using SDEs,
|
|
5
|
+
as described in Song et al. (2021, "Score-Based Generative Modeling through Stochastic
|
|
6
|
+
Differential Equations"). It provides components for forward and reverse diffusion
|
|
7
|
+
processes, hyperparameter management, training, and image sampling, supporting Variance
|
|
8
|
+
Exploding (VE), Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE
|
|
9
|
+
methods for flexible noise schedules. Supports both unconditional and conditional
|
|
10
|
+
generation with text prompts.
|
|
11
|
+
|
|
12
|
+
**Components**
|
|
13
|
+
|
|
14
|
+
- **ForwardSDE**: Forward diffusion process to add noise using SDE methods.
|
|
15
|
+
- **ReverseSDE**: Reverse diffusion process to denoise using SDE methods.
|
|
16
|
+
- **VarianceSchedulerSDE**: Noise schedule and SDE-specific parameter management.
|
|
17
|
+
- **TrainSDE**: Training loop with mixed precision and scheduling.
|
|
18
|
+
- **SampleSDE**: Image generation from trained SDE models.
|
|
19
|
+
|
|
20
|
+
**References**
|
|
21
|
+
|
|
22
|
+
- Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." arXiv preprint arXiv:2011.13456 (2020).
|
|
23
|
+
|
|
24
|
+
---------------------------------------------------------------------------------
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
import torch
|
|
29
|
+
import torch.nn as nn
|
|
30
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
31
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
32
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
33
|
+
import torch.distributed as dist
|
|
34
|
+
from typing import Optional, Tuple, Callable, List, Any, Union, Self
|
|
35
|
+
from tqdm import tqdm
|
|
36
|
+
from torch.optim.lr_scheduler import LambdaLR
|
|
37
|
+
from transformers import BertTokenizer
|
|
38
|
+
import warnings
|
|
39
|
+
from torchvision.utils import save_image
|
|
40
|
+
import os
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
###==================================================================================================================###
|
|
44
|
+
|
|
45
|
+
class ForwardSDE(nn.Module):
    """Forward diffusion process for SDE-based generative models.

    Implements the forward diffusion process for score-based generative models using
    Stochastic Differential Equations (SDEs), supporting Variance Exploding (VE),
    Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE methods, as
    described in Song et al. (2021).

    Parameters
    ----------
    variance_scheduler : object
        Hyperparameter object (VarianceSchedulerSDE) containing SDE-specific parameters.
        Expected to expose `dt`, `sigmas`, `betas` and `_cum_betas`.
    sde_method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".
    """
    def __init__(self, variance_scheduler: torch.nn.Module, sde_method: str) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler
        self.sde_method = sde_method

    def _sigma_increment(self, time_steps: torch.Tensor) -> torch.Tensor:
        """Return ``sqrt(max(sigma_t**2 - sigma_{t-1}**2, 0))`` for every sample.

        ``sigma_{-1}`` is taken to be 0. The previous sigma is resolved per element,
        so a batch that mixes ``t == 0`` with ``t > 0`` gets the correct increment for
        every sample. Index ``-1`` (which would wrap to the last sigma) is never used.
        """
        sigmas = self.variance_scheduler.sigmas
        sigma_t = sigmas[time_steps]
        prev_index = (time_steps - 1).clamp(min=0)
        sigma_prev = torch.where(time_steps > 0, sigmas[prev_index], torch.zeros_like(sigma_t))
        return torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_prev ** 2, min=0))

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward SDE diffusion process to the input data.

        Perturbs the input data `x0` by adding noise according to the specified SDE
        method at given time steps, incorporating drift and diffusion terms as applicable.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor, shape (batch_size, channels, height, width).
        noise : torch.Tensor
            Gaussian noise tensor, same shape as `x0` (unused by "ode").
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If `sde_method` is not one of "ve", "vp", "sub-vp", "ode".
        """
        dt = self.variance_scheduler.dt
        if self.sde_method == "ve":
            # NOTE(review): x0 is perturbed by the single-step increment
            # sqrt(sigma_t^2 - sigma_{t-1}^2) rather than by the full marginal sigma_t
            # of the VE perturbation kernel -- confirm this is intended for training.
            sigma_diff = self._sigma_increment(time_steps)
            x0 = x0 + noise * sigma_diff.view(-1, 1, 1, 1)

        elif self.sde_method == "vp":
            # `betas` is a property, so the trainable case is handled transparently
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * dt) * noise
            x0 = x0 + drift + diffusion

        elif self.sde_method == "sub-vp":
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            cum_betas = self.variance_scheduler._cum_betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
            x0 = x0 + drift + diffusion

        elif self.sde_method == "ode":
            # deterministic: drift only, `noise` is ignored
            betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
            drift = -0.5 * betas * x0 * dt
            x0 = x0 + drift
        else:
            raise ValueError(f"Unknown method: {self.sde_method}")
        return x0
|
|
118
|
+
|
|
119
|
+
###==================================================================================================================###
|
|
120
|
+
|
|
121
|
+
class ReverseSDE(nn.Module):
    """Reverse diffusion process for SDE-based generative models.

    Implements the reverse diffusion process for score-based generative models using
    Stochastic Differential Equations (SDEs), supporting Variance Exploding (VE),
    Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE methods, as
    described in Song et al. (2021). The reverse process denoises a noisy input using
    predicted noise estimates.

    Parameters
    ----------
    variance_scheduler : object
        Hyperparameter object (VarianceSchedulerSDE) containing SDE-specific parameters.
        Expected to expose `dt`, `sigmas`, `betas` and `_cum_betas`.
    sde_method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".
    """
    def __init__(self, variance_scheduler: torch.nn.Module, sde_method: str) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler
        self.sde_method = sde_method

    def _sigma_sq_increment(self, time_steps: torch.Tensor) -> torch.Tensor:
        """Return ``sigma_t**2 - sigma_{t-1}**2`` per sample, with ``sigma_{-1} = 0``.

        The previous sigma is resolved element-wise, so a batch that mixes ``t == 0``
        with ``t > 0`` gets the correct increment for every sample. Index ``-1``
        (which would wrap to the last sigma) is never used.
        """
        sigmas = self.variance_scheduler.sigmas
        sigma_t = sigmas[time_steps]
        prev_index = (time_steps - 1).clamp(min=0)
        sigma_prev = torch.where(time_steps > 0, sigmas[prev_index], torch.zeros_like(sigma_t))
        return sigma_t ** 2 - sigma_prev ** 2

    def forward(
        self,
        xt: torch.Tensor,
        noise: Optional[torch.Tensor],
        predicted_noise: torch.Tensor,
        time_steps: torch.Tensor,
    ) -> torch.Tensor:
        """Applies the reverse SDE diffusion process to the noisy input.

        Denoises the input `xt` by applying the reverse SDE process, using predicted
        noise estimates and optional stochastic noise, according to the specified SDE
        method at given time steps. Incorporates drift and diffusion terms as applicable.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        noise : torch.Tensor or None
            Gaussian noise tensor, same shape as `xt`, used for stochasticity. If None,
            no stochastic noise is added.
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt (torch.Tensor) - Denoised tensor at the previous time step, same shape as `xt`.

        Raises
        ------
        ValueError
            If `sde_method` is not one of "ve", "vp", "sub-vp", "ode".

        **Notes**

        - For the "ve" and "ode" methods, the output is clamped to [-1e5, 1e5] to prevent numerical instability.
        - `noise` is ignored by "ode", which is deterministic.
        """
        dt = self.variance_scheduler.dt
        # `betas` / `_cum_betas` are properties, so trainable schedules are handled too
        betas = self.variance_scheduler.betas[time_steps].view(-1, 1, 1, 1)
        cum_betas = self.variance_scheduler._cum_betas[time_steps].view(-1, 1, 1, 1)
        if self.sde_method == "ve":
            sigma_sq_diff = self._sigma_sq_increment(time_steps).view(-1, 1, 1, 1)
            drift = -sigma_sq_diff * predicted_noise * dt
            diffusion = torch.sqrt(torch.clamp(sigma_sq_diff, min=0)) * noise if noise is not None else 0
            xt = xt + drift + diffusion
            xt = torch.clamp(xt, -1e5, 1e5)

        elif self.sde_method == "vp":
            # NOTE(review): the xt term has the same sign as the forward drift, and the
            # raw predicted noise (not a score scaled by 1/sigma) is used. A reverse-time
            # Euler step of Song et al.'s reverse SDE would flip the sign of the xt term.
            # Confirm the intended convention before changing.
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.sde_method == "sub-vp":
            drift = -0.5 * betas * xt * dt - betas * (1 - torch.exp(-2 * cum_betas)) * predicted_noise * dt
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.sde_method == "ode":
            drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
            xt = xt + drift
            xt = torch.clamp(xt, -1e5, 1e5)
        else:
            raise ValueError(f"Unknown method: {self.sde_method}")
        return xt
|
|
203
|
+
|
|
204
|
+
###==================================================================================================================###
|
|
205
|
+
|
|
206
|
+
class VarianceSchedulerSDE(nn.Module):
    """Hyperparameters for SDE-based generative models.

    Manages the noise schedule and SDE-specific parameters for score-based generative
    models, including beta and sigma schedules, time steps, and variance computations,
    as described in Song et al. (2021). Supports trainable or fixed beta schedules and
    multiple scheduling methods for flexible noise control.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Starting value for beta schedule (default: 1e-4).
    beta_end : float, optional
        Ending value for beta schedule (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    sigma_start : float, optional
        Starting value for sigma schedule for VE method (default: 1e-3).
    sigma_end : float, optional
        Ending value for sigma schedule for VE method (default: 10.0).
    start : float, optional
        Start of the time interval for SDE integration (default: 0.0).
    end : float, optional
        End of the time interval for SDE integration (default: 1.0).

    Raises
    ------
    ValueError
        If the beta or sigma ranges are not ``0 < start < end``, if ``num_steps`` is
        not positive, if ``end <= start``, or if ``beta_method`` is unknown.
    """
    def __init__(
        self,
        num_steps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        trainable_beta: bool = False,
        beta_method: str = "linear",
        sigma_start: float = 1e-3,
        sigma_end: float = 10.0,
        start: float = 0.0,
        end: float = 1.0
    ) -> None:
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")
        if self.end <= self.start:
            # a non-positive interval would give dt <= 0 and a degenerate time grid
            raise ValueError(f"end ({self.end}) must be greater than start ({self.start})")

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)
        # Non-persistent buffer: follows .to(device) (needed by the trainable `sigmas`
        # path) without adding a key to state_dict, so old checkpoints still load.
        self.register_buffer(
            "time",
            torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32),
            persistent=False,
        )
        self.dt = (self.end - self.start) / self.num_steps

        if trainable_beta:
            # Reparameterize: beta = beta_start + (beta_end - beta_start) * sigmoid(beta_raw).
            # eps keeps the endpoints (normalized 0 and 1) away from logit's +/-inf.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=1e-6))
        else:
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
            self.register_buffer("sigmas_buffer", self.sigma_start * (self.sigma_end / self.sigma_start) ** self.time)

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable."""
        if self.trainable_beta:
            # transform unconstrained parameters to the valid beta range using sigmoid
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        else:
            return self._buffers['betas_buffer']

    @property
    def _cum_betas(self) -> torch.Tensor:
        """Returns ``cumsum(betas) * dt``, computed dynamically if trainable."""
        if self.trainable_beta:
            return torch.cumsum(self.betas, dim=0) * self.dt
        else:
            return self._buffers['cum_betas']

    @property
    def sigmas(self) -> torch.Tensor:
        """Returns ``sigma_start * (sigma_end / sigma_start) ** time``, computed dynamically if trainable."""
        if self.trainable_beta:
            return self.sigma_start * (self.sigma_end / self.sigma_start) ** self.time
        else:
            return self._buffers['sigmas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the SDE noise schedule using the chosen
        method, ensuring values are clamped within the specified range.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        betas (torch.Tensor) - Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not supported.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / (beta.max() - beta.min())
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_variance(self, time_steps: torch.Tensor, method: str) -> torch.Tensor:
        """Computes the variance for the specified SDE method at given time steps.

        Calculates the variance used in SDE diffusion processes based on the method
        (VE, VP, or sub-VP), leveraging the sigma or cumulative beta schedules. Works
        for both fixed and trainable beta schedules.

        Parameters
        ----------
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, num_steps - 1].
        method : str
            SDE method to compute variance for. Supported methods: "ve", "vp", "sub-vp".

        Returns
        -------
        variance_values (torch.Tensor) - Variance values for the specified time steps, shape (batch_size,).

        Raises
        ------
        ValueError
            If `method` is not supported.
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        elif method == "vp":
            # `_cum_betas` (not the `cum_betas` buffer) also exists when betas are trainable
            return 1 - torch.exp(-self._cum_betas[time_steps])
        elif method == "sub-vp":
            # NOTE(review): Song et al. give the sub-VP marginal variance as
            # (1 - exp(-int beta))**2; this returns 1 - exp(-2 * int beta), which is the
            # sub-VP diffusion factor. Confirm which quantity callers expect.
            return 1 - torch.exp(-2 * self._cum_betas[time_steps])
        else:
            raise ValueError(f"Unknown method: {method}")
|
|
370
|
+
|
|
371
|
+
###==================================================================================================================###
|
|
372
|
+
|
|
373
|
+
class TrainSDE(nn.Module):
|
|
374
|
+
"""Trainer for score-based generative models using Stochastic Differential Equations.
|
|
375
|
+
|
|
376
|
+
Manages the training process for SDE-based generative models, optimizing a noise
|
|
377
|
+
predictor to learn the noise added by the forward SDE process, as described in Song
|
|
378
|
+
et al. (2021). Supports conditional training with text prompts, mixed precision,
|
|
379
|
+
learning rate scheduling, early stopping, and checkpointing.
|
|
380
|
+
|
|
381
|
+
Parameters
|
|
382
|
+
----------
|
|
383
|
+
|
|
384
|
+
noise_predictor : nn.Module
|
|
385
|
+
Model to predict noise added during the forward SDE process.
|
|
386
|
+
forward_diffusion : nn.Module
|
|
387
|
+
Forward SDE diffusion module for adding noise.
|
|
388
|
+
reverse_diffusion: nn.Module
|
|
389
|
+
Reverse SDE diffusion module for denoising.
|
|
390
|
+
data_loader : torch.utils.data.DataLoader
|
|
391
|
+
DataLoader for training data.
|
|
392
|
+
optimizer : torch.optim.Optimizer
|
|
393
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
394
|
+
objective : callable
|
|
395
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
396
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
397
|
+
DataLoader for validation data, default None.
|
|
398
|
+
max_epochs : int, optional
|
|
399
|
+
Maximum number of training epochs (default: 1000).
|
|
400
|
+
device : torch.device, optional
|
|
401
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
402
|
+
conditional_model : nn.Module, optional
|
|
403
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
404
|
+
metrics_ : object, optional
|
|
405
|
+
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
406
|
+
bert_tokenizer : BertTokenizer, optional
|
|
407
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
408
|
+
max_token_length : int, optional
|
|
409
|
+
Maximum length for tokenized prompts (default: 77).
|
|
410
|
+
store_path : str, optional
|
|
411
|
+
Path to save model checkpoints (default: "sde_model.pth").
|
|
412
|
+
patience : int, optional
|
|
413
|
+
Number of epochs to wait for improvement before early stopping (default: 10).
|
|
414
|
+
warmup_epochs : int, optional
|
|
415
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
416
|
+
val_frequency : int, optional
|
|
417
|
+
Frequency (in epochs) for validation (default: 10).
|
|
418
|
+
image_output_range : tuple, optional
|
|
419
|
+
Range for clamping generated images (default: (-1, 1)).
|
|
420
|
+
normalize_output : bool, optional
|
|
421
|
+
Whether to normalize generated images to [0, 1] for metrics (default: True).
|
|
422
|
+
use_ddp : bool, optional
|
|
423
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
424
|
+
grad_accumulation_steps : int, optional
|
|
425
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
426
|
+
log_frequency : int, optional
|
|
427
|
+
Number of epochs before printing loss.
|
|
428
|
+
use_compilation : bool, optional
|
|
429
|
+
whether the model is internally compiled using torch.compile (default: false)
|
|
430
|
+
"""
|
|
431
|
+
def __init__(
    self,
    noise_predictor: torch.nn.Module,
    forward_diffusion: torch.nn.Module,
    reverse_diffusion: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    conditional_model: Optional[torch.nn.Module] = None,
    metrics_: Optional[Any] = None,
    bert_tokenizer: Optional[BertTokenizer] = None,
    max_token_length: int = 77,
    store_path: Optional[str] = None,
    patience: int = 100,
    warmup_epochs: int = 100,
    val_frequency: int = 10,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    normalize_output: bool = True,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False
) -> None:
    """Set up devices, models, learning-rate schedules and the text tokenizer.

    Defaults and behavior established here:

    - ``device`` defaults to CUDA when available, else CPU. With ``use_ddp`` it is
      replaced by ``cuda:<LOCAL_RANK>`` during DDP setup.
    - All models are moved to ``self.device``. ``conditional_model`` is kept only
      when it is not None.
    - ``store_path`` defaults to ``"sde_model"``.
    - ``patience`` (default 100) is also the ``ReduceLROnPlateau`` patience
      (reduction factor 0.5).
    - ``warmup_epochs`` is forwarded to ``self.warmup_scheduler``.
    - When ``bert_tokenizer`` is None, ``"bert-base-uncased"`` is loaded.

    Raises
    ------
    ValueError
        If the default tokenizer cannot be loaded, or if DDP is requested without
        the RANK / LOCAL_RANK / WORLD_SIZE environment variables.
    RuntimeError
        If DDP is requested but CUDA is not available.
    """
    super().__init__()
    # initialize DDP settings first
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # setup distributed training if enabled (may override self.device)
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # move models to appropriate device
    self.noise_predictor = noise_predictor.to(self.device)
    self.forward_diffusion = forward_diffusion.to(self.device)
    self.reverse_diffusion = reverse_diffusion.to(self.device)
    # explicit None check: container modules such as an empty nn.Sequential are falsy
    self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None

    # training components
    self.metrics_ = metrics_
    self.optimizer = optimizer
    self.objective = objective
    self.store_path = store_path or "sde_model"
    self.data_loader = data_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_frequency = val_frequency
    self.image_output_range = image_output_range
    self.normalize_output = normalize_output
    self.log_frequency = log_frequency
    self.use_compilation = use_compilation

    # learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    # initialize tokenizer
    if bert_tokenizer is None:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            # keep the original error as the cause for easier debugging
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e
    else:
        self.tokenizer = bert_tokenizer
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def _setup_ddp(self) -> None:
    """Configure this process for DistributedDataParallel training.

    Checks that the RANK, LOCAL_RANK and WORLD_SIZE environment variables are
    present and that CUDA is available. Creates an NCCL process group unless one
    already exists, then binds the process to ``cuda:<LOCAL_RANK>``. Rank 0 becomes
    the master process.

    Raises
    ------
    ValueError
        If any of RANK, LOCAL_RANK or WORLD_SIZE is missing (checked in that order).
    RuntimeError
        If CUDA is not available.
    """
    required_vars = ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    for var_name in required_vars:
        if var_name not in os.environ:
            raise ValueError(f"DDP enabled but {var_name} environment variable not set")

    if not torch.cuda.is_available():
        raise RuntimeError("DDP requires CUDA but CUDA is not available")

    # only create a process group if the caller has not done so already
    if not torch.distributed.is_initialized():
        init_process_group(backend="nccl")

    # global rank, rank on this node, total number of processes
    self.ddp_rank, self.ddp_local_rank, self.ddp_world_size = (
        int(os.environ[var_name]) for var_name in required_vars
    )

    local_device = torch.device(f"cuda:{self.ddp_local_rank}")
    self.device = local_device
    torch.cuda.set_device(local_device)

    # rank 0 handles logging, checkpointing, etc.
    self.master_process = self.ddp_rank == 0
    if self.master_process:
        print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
551
|
+
|
|
552
|
+
def _setup_single_gpu(self) -> None:
    """Configure a non-distributed (single GPU or CPU) run.

    Presents a one-process world: global rank 0, local rank 0 and world size 1.
    This process is the master (the one that logs and checkpoints).
    """
    self.ddp_rank, self.ddp_local_rank, self.ddp_world_size = 0, 0, 1
    self.master_process = True
|
|
558
|
+
|
|
559
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a training checkpoint to resume training.

    Restores the noise predictor, the conditional model (if applicable), the
    variance scheduler(s) and the optimizer from a saved checkpoint.

    Model weights are matched against the *unwrapped* model: leading ``module.``
    (DistributedDataParallel) and ``_orig_mod.`` (``torch.compile``) prefixes are
    stripped from the checkpoint keys, and the weights are loaded into the innermost
    module. Loading therefore works whether or not the model has already been
    wrapped for DDP or compiled when this is called (``forward`` only wraps the
    models when training starts).

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch stored in the checkpoint, or -1 if it is absent.
    loss : float
        The loss stored in the checkpoint, or ``inf`` if it is absent.

    Raises
    ------
    FileNotFoundError
        If no file exists at ``checkpoint_path``.
    KeyError
        If the checkpoint lacks the noise predictor, variance scheduler or
        optimizer entries.
    """
    wrapper_prefixes = ("module.", "_orig_mod.")

    def _unwrap(model: nn.Module) -> nn.Module:
        # peel DDP (.module) and torch.compile (._orig_mod) wrappers in any order
        while True:
            if isinstance(model, DDP):
                model = model.module
            elif isinstance(getattr(model, "_orig_mod", None), nn.Module):
                model = model._orig_mod
            else:
                return model

    def _strip_prefixes(state_dict: dict) -> dict:
        # strip only *leading* wrapper prefixes; a global str.replace would also
        # mangle inner names such as '0.module.weight'
        cleaned = {}
        for key, value in state_dict.items():
            while key.startswith(wrapper_prefixes):
                key = key.split(".", 1)[1]
            cleaned[key] = value
        return cleaned

    try:
        # load checkpoint with proper device mapping
        # NOTE(review): a non-nn.Module variance scheduler is stored as a pickled
        # object, so only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")

    # load noise predictor state
    if 'model_state_dict_noise_predictor' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
    _unwrap(self.noise_predictor).load_state_dict(
        _strip_prefixes(checkpoint['model_state_dict_noise_predictor'])
    )

    # load conditional model state if applicable
    if self.conditional_model is not None:
        cond_state_dict = checkpoint.get('model_state_dict_conditional')
        if cond_state_dict is not None:
            _unwrap(self.conditional_model).load_state_dict(_strip_prefixes(cond_state_dict))
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_conditional' or it is None, "
                "skipping conditional model loading"
            )

    # load variance_scheduler state
    if 'variance_scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'variance_scheduler_model' key")
    scheduler_state = checkpoint['variance_scheduler_model']
    try:
        # restore each diffusion's scheduler independently: nn.Module schedulers
        # load the saved state dict, any other scheduler is replaced by the saved object
        for diffusion in (self.forward_diffusion, self.reverse_diffusion):
            if isinstance(diffusion.variance_scheduler, nn.Module):
                diffusion.variance_scheduler.load_state_dict(scheduler_state)
            else:
                diffusion.variance_scheduler = scheduler_state
    except Exception as e:
        warnings.warn(f"Variance_scheduler loading failed: {e}. Continuing with current variance_scheduler.")

    # load optimizer state
    if 'optimizer_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
    try:
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
    return epoch, loss
|
|
642
|
+
|
|
643
|
+
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a LambdaLR scheduler implementing a linear learning-rate warmup.

    For scheduler step ``k`` the multiplier on each parameter group's initial
    learning rate is ``k / warmup_epochs`` while ``k < warmup_epochs`` and ``1.0``
    afterwards. The rate therefore starts at 0 and reaches the optimizer's
    configured value after ``warmup_epochs`` steps, then stays there.

    NOTE(review): ``k`` counts ``scheduler.step()`` calls. The trainer's ``forward``
    calls ``self.warmup_lr_scheduler.step()`` once per optimizer update. If that
    scheduler comes from this method, the warmup is measured in updates, not
    epochs — confirm.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is scheduled.
    warmup_epochs : int
        Length of the warmup phase, in scheduler steps.

    Returns
    -------
    torch.optim.lr_scheduler.LambdaLR
        The warmup scheduler.
    """
    def _warmup_factor(step_index):
        return step_index / warmup_epochs if step_index < warmup_epochs else 1.0

    return LambdaLR(optimizer, lr_lambda=_warmup_factor)
|
|
669
|
+
|
|
670
|
+
def _wrap_models_for_ddp(self) -> None:
    """Wrap models with DistributedDataParallel for multi-GPU training.

    Does nothing unless ``self.use_ddp`` is set. Otherwise it wraps the noise
    predictor and, if present, the conditional model in ``DDP`` on device
    ``self.ddp_local_rank``.

    Models that are already ``DDP`` instances are left untouched. Calling this
    more than once, for example when ``forward`` is invoked again, therefore does
    not nest one DDP wrapper inside another.
    """
    if not self.use_ddp:
        return

    def _wrap(model: nn.Module) -> nn.Module:
        if isinstance(model, DDP):
            return model
        # find_unused_parameters=True tolerates parameters that receive no
        # gradient in an iteration, at some extra per-iteration cost
        return DDP(
            model,
            device_ids=[self.ddp_local_rank],
            find_unused_parameters=True
        )

    self.noise_predictor = _wrap(self.noise_predictor)
    if self.conditional_model is not None:
        self.conditional_model = _wrap(self.conditional_model)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def forward(self) -> Tuple[List, float]:
    """Trains the SDE model to predict noise added by the forward diffusion process.

    Executes the training loop, optimizing the noise predictor and conditional model
    (if applicable) using mixed precision, gradient clipping, and learning rate
    scheduling. Supports validation, early stopping, and checkpointing.

    Returns
    -------
    train_losses : list of float
        Mean training loss per epoch. Under DDP this is the loss averaged across
        all processes.
    best_val_loss : float
        Best validation or training loss achieved. It is only tracked on the
        master process; other ranks return ``inf``.

    **Notes**

    - Mixed precision uses ``torch.autocast`` on the device type the data lives on
      (CUDA, otherwise CPU). Loss scaling with ``torch.GradScaler`` is enabled
      only on CUDA.
    - Checkpoints are saved when the validation (or training) loss improves, and on early stopping.
    - Early stopping is triggered if no improvement occurs for `patience` epochs.
      Under DDP the master's decision is broadcast, so all ranks stop together.
    """
    # set models to training mode; the diffusion modules are only put in train
    # mode when the variance scheduler's betas are trainable
    self.noise_predictor.train()
    if self.conditional_model is not None:
        self.conditional_model.train()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.reverse_diffusion.train()
        self.forward_diffusion.train()
    else:
        self.reverse_diffusion.eval()
        self.forward_diffusion.eval()

    # compile models for optimization (if supported)
    if self.use_compilation:
        try:
            self.noise_predictor = torch.compile(self.noise_predictor)
            if self.conditional_model is not None:
                self.conditional_model = torch.compile(self.conditional_model)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    # wrap models for DDP after compilation
    self._wrap_models_for_ddp()

    # Derive the device *type* explicitly. A torch.device such as cuda:0 never
    # compares equal to the string 'cuda', which previously forced CPU autocast
    # even for CUDA tensors. Non-CUDA/CPU devices keep the previous 'cpu' autocast.
    device_type = torch.device(self.device).type
    autocast_device = device_type if device_type in ("cuda", "cpu") else "cpu"

    # initialize training components
    scaler = torch.GradScaler(device=autocast_device, enabled=autocast_device == "cuda")
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    # main training loop
    for epoch in range(self.max_epochs):
        # set epoch for distributed sampler if using DDP
        if self.use_ddp and hasattr(self.data_loader.sampler, 'set_epoch'):
            self.data_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        # training step loop with gradient accumulation
        for step, (x, y) in enumerate(tqdm(self.data_loader, disable=not self.master_process)):
            x = x.to(self.device)

            # process conditional inputs if conditional model exists
            if self.conditional_model is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            # forward pass with mixed precision
            with torch.autocast(device_type=autocast_device):
                noise = torch.randn_like(x)
                t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

                noisy_x = self.forward_diffusion(x, noise, t)
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)

                # compute loss and scale for gradient accumulation
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # backward pass outside the autocast region (recommended by the AMP docs)
            scaler.scale(loss).backward()

            # gradient accumulation and optimizer step
            if (step + 1) % self.grad_accumulation_steps == 0:
                # clip gradients on unscaled values
                scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
                if self.conditional_model is not None:
                    torch.nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)

                scaler.step(self.optimizer)
                scaler.update()
                self.optimizer.zero_grad()

                # update learning rate (warmup scheduler), once per optimizer update
                self.warmup_lr_scheduler.step()

            # record loss (undo the accumulation scaling)
            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # compute mean training loss, averaged across processes for DDP
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()
        # record after the all-reduce so the history matches the logged (global) loss
        train_losses.append(mean_train_loss)

        # print training progress (only master process)
        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        # validation step (run on every rank; validate() all-reduces its loss)
        if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
            val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            current_best = val_loss
            self.scheduler.step(val_loss)
        else:
            if self.master_process:
                print()
            current_best = mean_train_loss
            self.scheduler.step(mean_train_loss)

        # save checkpoint and early stopping decision (only master process)
        stop_training = False
        if self.master_process:
            if current_best < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                best_val_loss = current_best
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, best_val_loss, "_early_stop")
                    stop_training = True

        # Share the master's decision so every rank leaves the loop together;
        # a rank that kept training would block forever in the next collective.
        if self.use_ddp:
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())
        if stop_training:
            break

    # clean up DDP
    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
850
|
+
|
|
851
|
+
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
    """Encode a batch of conditioning inputs with the conditional model.

    Every element of ``y`` is converted to a string. A tensor is first moved to the
    CPU and turned into a Python list via NumPy. The strings are tokenized with
    ``self.tokenizer`` (padded and truncated to ``self.max_token_length``). The
    resulting token ids and attention mask are passed to
    ``self.conditional_model``.

    Parameters
    ----------
    y : torch.Tensor or list
        Conditional input (text prompts or labels).

    Returns
    -------
    torch.Tensor
        Output of the conditional model for the batch.
    """
    if isinstance(y, torch.Tensor):
        raw_items = y.cpu().numpy().tolist()
    else:
        raw_items = y
    texts = list(map(str, raw_items))

    tokens = self.tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt",
    ).to(self.device)

    # NOTE(review): the attention mask is passed unchanged here, whereas SampleSDE
    # passes `attention_mask == 0` to its conditional model — confirm which
    # convention the text encoder expects.
    return self.conditional_model(tokens["input_ids"], tokens["attention_mask"])
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "") -> None:
    """Save model checkpoint (only called by master process).

    Writes ``sde_epoch_{epoch}{suffix}.pth`` into ``self.store_path``, creating the
    directory if needed. The file contains:

    - the epoch;
    - the noise-predictor, conditional-model and optimizer state dicts;
    - the loss;
    - the variance scheduler (its state dict if it is an ``nn.Module``, otherwise
      the object itself);
    - ``max_epochs``.

    Model weights are taken from the innermost module. DistributedDataParallel
    (``.module``) and ``torch.compile`` (``._orig_mod``) wrappers are peeled off,
    so saved keys carry no ``module.`` / ``_orig_mod.`` prefixes and load directly
    into a plain, unwrapped model.

    Any failure is reported with a printed message instead of being raised, so a
    failed save does not abort training.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    loss : float
        Current loss value.
    suffix : str, optional
        Suffix to add to checkpoint filename.
    """
    def _unwrap(model: nn.Module) -> nn.Module:
        # peel DDP (.module) and torch.compile (._orig_mod) wrappers in any order
        while True:
            if isinstance(model, DDP):
                model = model.module
            elif isinstance(getattr(model, "_orig_mod", None), nn.Module):
                model = model._orig_mod
            else:
                return model

    try:
        noise_predictor_state = _unwrap(self.noise_predictor).state_dict()
        conditional_state = None
        if self.conditional_model is not None:
            conditional_state = _unwrap(self.conditional_model).state_dict()

        scheduler = self.forward_diffusion.variance_scheduler
        checkpoint = {
            'epoch': epoch,
            'model_state_dict_noise_predictor': noise_predictor_state,
            'model_state_dict_conditional': conditional_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
            'variance_scheduler_model': (
                scheduler.state_dict() if isinstance(scheduler, nn.Module) else scheduler
            ),
            'max_epochs': self.max_epochs,
        }

        filename = f"sde_epoch_{epoch}{suffix}.pth"
        filepath = os.path.join(self.store_path, filename)
        os.makedirs(self.store_path, exist_ok=True)
        torch.save(checkpoint, filepath)

        print(f"Model saved at epoch {epoch}")

    except Exception as e:
        # best-effort: report and keep training
        print(f"Failed to save model: {e}")
|
|
932
|
+
|
|
933
|
+
|
|
934
|
+
def validate(self) -> Tuple[float, float, float, float, float, float]:
    """Validates the noise predictor and computes evaluation Metrics.

    Computes the validation loss, i.e. ``self.objective`` between predicted and
    ground-truth noise. If ``self.metrics_`` is provided, it also generates samples
    by running the reverse diffusion manually over all timesteps, clamps and
    (optionally) normalizes them, and computes image-domain metrics (MSE, PSNR,
    SSIM, FID, LPIPS) against the validation images.

    The reverse process receives ``None`` as noise when its SDE method is
    ``"ode"``. The method is read from ``sde_method``, as ``SampleSDE`` does,
    falling back to ``method``.

    Models are put back into training mode afterwards, even if validation raises.

    Returns
    -------
    val_loss : float
        Mean validation loss (averaged across processes under DDP).
    fid : float, or `float('inf')` if not computed
        Mean FID score.
    mse : float, or None if not computed
        Mean MSE
    psnr : float, or None if not computed
        Mean PSNR
    ssim : float, or None if not computed
        Mean SSIM
    lpips_score : float, or None if not computed
        Mean LPIPS score
    """
    self.noise_predictor.eval()
    if self.conditional_model is not None:
        self.conditional_model.eval()
    if self.forward_diffusion.variance_scheduler.trainable_beta:
        self.forward_diffusion.eval()
        self.reverse_diffusion.eval()

    # ODE sampling is deterministic and takes no noise
    sde_method = getattr(
        self.reverse_diffusion, "sde_method", getattr(self.reverse_diffusion, "method", None)
    )

    val_losses = []
    fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

    try:
        with torch.no_grad():
            for x, y in self.val_loader:
                x = x.to(self.device)
                x_orig = x.clone()

                # process conditional input
                if self.conditional_model is not None:
                    y_encoded = self._process_conditional_input(y)
                else:
                    y_encoded = None

                # compute validation loss
                noise = torch.randn_like(x)
                t = torch.randint(0, self.forward_diffusion.variance_scheduler.num_steps, (x.shape[0],)).to(self.device)

                noisy_x = self.forward_diffusion(x, noise, t)
                predicted_noise = self.noise_predictor(noisy_x, t, y_encoded, None)
                loss = self.objective(predicted_noise, noise)
                val_losses.append(loss.item())

                # generate samples for metrics evaluation
                if self.metrics_ is not None and self.reverse_diffusion is not None:
                    xt = torch.randn_like(x)

                    # reverse diffusion sampling, from the last timestep down to 0
                    for step_t in reversed(range(self.forward_diffusion.variance_scheduler.num_steps)):
                        time_steps = torch.full((xt.shape[0],), step_t, device=self.device, dtype=torch.long)
                        predicted_noise = self.noise_predictor(xt, time_steps, y_encoded, None)
                        step_noise = torch.randn_like(xt) if sde_method != "ode" else None
                        xt = self.reverse_diffusion(xt, step_noise, predicted_noise, time_steps)

                    # clamp and normalize generated samples
                    x_hat = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
                    if self.normalize_output:
                        x_hat = (x_hat - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
                        x_orig = (x_orig - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])

                    # compute metrics
                    fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x_orig, x_hat)

                    if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                        fid_scores.append(fid)
                    if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                        mse_scores.append(mse)
                        psnr_scores.append(psnr)
                        ssim_scores.append(ssim)
                    if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                        lpips_scores.append(lpips_score)
    finally:
        # return to training mode even if validation raised
        self.noise_predictor.train()
        if self.conditional_model is not None:
            self.conditional_model.train()
        if self.forward_diffusion.variance_scheduler.trainable_beta:
            self.reverse_diffusion.train()
            self.forward_diffusion.train()

    # compute average metrics
    val_loss = torch.tensor(val_losses).mean().item()

    # all-reduce the validation loss across processes for DDP
    if self.use_ddp:
        val_loss_tensor = torch.tensor(val_loss, device=self.device)
        dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
        val_loss = val_loss_tensor.item()

    fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
    mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
    psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
    ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
    lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
###==================================================================================================================###
|
|
1044
|
+
|
|
1045
|
+
class SampleSDE(nn.Module):
    """Sampler for generating images using SDE-based generative models.

    Generates images by iteratively denoising random noise using the reverse SDE process
    and a trained noise predictor, as described in Song et al. (2021). Supports both
    unconditional and conditional generation with text prompts.

    Parameters
    ----------
    reverse_diffusion : ReverseSDE
        Reverse SDE diffusion module for denoising. Its ``variance_scheduler.num_steps``
        and ``sde_method`` attributes are read during sampling.
    noise_predictor : nn.Module
        Model to predict noise added during the forward SDE process.
    image_shape : tuple
        Shape of generated images as (height, width).
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., TextEncoder), default None.
    tokenizer : str or tokenizer object, optional
        Either a name/path loaded with ``BertTokenizer.from_pretrained`` (default
        "bert-base-uncased"), or an already-constructed tokenizer, which is used as-is.
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    image_output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Raises
    ------
    ValueError
        If ``image_shape``, ``batch_size`` or ``image_output_range`` is invalid.
    """
    def __init__(
        self,
        reverse_diffusion: torch.nn.Module,
        noise_predictor: torch.nn.Module,
        image_shape: Tuple[int, int],
        conditional_model: Optional[torch.nn.Module] = None,
        tokenizer: Union[str, "BertTokenizer"] = "bert-base-uncased",
        max_token_length: int = 77,
        batch_size: int = 1,
        in_channels: int = 3,
        device: Optional[Union[str, torch.device]] = None,
        image_output_range: Tuple[float, float] = (-1.0, 1.0)
    ) -> None:
        super().__init__()
        # validate cheap arguments before moving models or loading a tokenizer
        if not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2 or not all(isinstance(s, int) and s > 0 for s in image_shape):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(image_output_range, (tuple, list)) or len(image_output_range) != 2 or image_output_range[0] >= image_output_range[1]:
            raise ValueError("image_output_range must be a tuple (min, max) with min < max")

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device
        self.reverse = reverse_diffusion.to(self.device)
        self.noise_predictor = noise_predictor.to(self.device)
        # `is not None`: an empty container module (e.g. nn.Sequential()) is falsy
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        # a name/path is loaded from pretrained; a tokenizer object is used as-is
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
        self.max_token_length = max_token_length
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.image_output_range = image_output_range

    def tokenize(self, prompts: Union[str, List]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenizes text prompts for conditional generation.

        Converts input prompts into tokenized tensors using the configured tokenizer.

        Parameters
        ----------
        prompts : str or list
            Text prompt(s) for conditional generation. Can be a single string or a list
            of strings.

        Returns
        -------
        input_ids : torch.Tensor
            Tokenized input IDs, shape (num_prompts, max_token_length).
        attention_mask : torch.Tensor
            Attention mask, shape (num_prompts, max_token_length).

        Raises
        ------
        TypeError
            If ``prompts`` is neither a string nor a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(
        self,
        conditions: Optional[Union[str, List]] = None,
        normalize_output: bool = True,
        save_images: bool = True,
        save_path: str = "sde_generated"
    ) -> torch.Tensor:
        """Generates images using the reverse SDE sampling process.

        Iteratively denoises random noise to generate images using the reverse SDE process
        and noise predictor. Supports conditional generation with text prompts. The
        prompt embedding is computed once and reused for every timestep.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).
        save_images : bool, optional
            If True, saves generated images to `save_path` (default: True).
        save_path : str, optional
            Directory to save generated images (default: "sde_generated").

        Returns
        -------
        generated_imgs (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). If `normalize_output` is True, images are normalized to [0, 1]; otherwise, they are clamped to `image_output_range`.

        Raises
        ------
        ValueError
            If conditions are given without a conditional model, or vice versa.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        noisy_samples = torch.randn(self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]).to(self.device)

        self.noise_predictor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            # the conditioning does not depend on the timestep, so encode it once
            y = None
            if self.conditional_model is not None:
                input_ids, attention_masks = self.tokenize(conditions)
                # NOTE(review): training passes the raw attention mask to the
                # conditional model, whereas this passes `mask == 0` — confirm which
                # convention the text encoder expects.
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            xt = noisy_samples
            for t in reversed(range(self.reverse.variance_scheduler.num_steps)):
                # the deterministic ODE sampler takes no noise
                noise = torch.randn_like(xt) if self.reverse.sde_method != "ode" else None
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)

                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                xt = self.reverse(xt, noise, predicted_noise, time_steps)

            generated_imgs = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
            if normalize_output:
                generated_imgs = (generated_imgs - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])

            # save images if save_images is True
            if save_images:
                os.makedirs(save_path, exist_ok=True)
                for i in range(generated_imgs.size(0)):
                    img_path = os.path.join(save_path, f"image_{i+1}.png")
                    save_image(generated_imgs[i], img_path)

        return generated_imgs

    def to(self, device: torch.device) -> Self:
        """Moves the module and its components to the specified device.

        Updates the device attribute and moves the reverse diffusion, noise predictor,
        and conditional model (if present) to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for the module and its components.

        Returns
        -------
        sample_sde (SampleSDE) - moved to the specified device.
        """
        self.device = device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)
|