TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/prior_diff.py
ADDED
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
class VarianceSchedulerUnCLIP(nn.Module):
    """Manages noise schedule parameters for UnCLIP diffusion models.

    Builds a beta schedule of length `num_steps`, the quantities derived from it
    (alphas, cumulative alpha products and their square roots) and a subsampled
    index schedule (`tau_indices`) of length `tau_num_steps`. The betas are either
    stored as a fixed buffer or learned through a sigmoid reparameterization that
    maps an unconstrained parameter back into the [`beta_start`, `beta_end`] range.

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the reverse process (default: 0, deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for sampling (default: 100).
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

    Raises
    ------
    ValueError
        If the beta range does not satisfy 0 < beta_start < beta_end < 1, if
        `num_steps` or `tau_num_steps` is not positive, or if `beta_method` is unknown.
    """

    # Clamp applied inside torch.logit so that schedule values sitting exactly on
    # beta_start / beta_end do not initialise the trainable parameter to -inf / +inf.
    _LOGIT_EPS = 1e-6

    def __init__(
            self,
            eta: Optional[float] = None,
            num_steps: int = 1000,
            tau_num_steps: int = 100,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
            trainable_beta: bool = False,
            beta_method: str = "linear"
    ) -> None:
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Store the inverse of the sigmoid reparameterization used by `betas`.
            # Several schedules touch the range end points exactly (normalized value
            # 0 or 1); `eps` keeps the resulting logits finite.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=self._LOGIT_EPS))
        else:
            # Fixed schedule: precompute every derived quantity once as buffers.
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Indices into the full schedule, spread from 0 to num_steps - 1.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,). For trainable schedules they are
            recomputed from `beta_raw` on every access; otherwise the stored
            buffer is returned.
        """
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

        Returns
        -------
        beta : torch.Tensor
            Tensor of beta values clamped to `beta_range`, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            spread = beta.max() - beta.min()
            if spread > 0:
                # Min-max rescale 1/t into [beta_min, beta_max].
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / spread
            else:
                # A single-step schedule has no spread; rescaling would be 0/0 (NaN).
                # Fall back to beta_min, which is what "linear" yields for one step.
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        elif method == "cosine":
            s = 0.008
            x = torch.linspace(0, num_steps, num_steps + 1)
            alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            # Per-step beta recovered from the ratio of consecutive cumulative alphas.
            beta = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        else:
            raise ValueError(f"Unknown beta_method: {method}")
        return torch.clamp(beta, min=beta_min, max=beta_max)

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes the subsampled (tau) noise schedule for UnCLIP.

        Selects the entries at `tau_indices` from the full schedule. Trainable
        schedules are recomputed from the current parameter; fixed schedules use
        the buffers registered at construction time.

        Returns
        -------
        tau_betas : torch.Tensor
            Beta values for subsampled steps, shape (tau_num_steps,).
        tau_alphas : torch.Tensor
            Alpha values for subsampled steps, shape (tau_num_steps,).
        tau_alpha_cumprod : torch.Tensor
            Cumulative product of alphas for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
        tau_sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            full_schedule = self.compute_schedule()
        else:
            full_schedule = (
                self.betas,
                self.alphas,
                self.alpha_cumprod,
                self.sqrt_alpha_cumprod,
                self.sqrt_one_minus_alpha_cumprod,
            )
        tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = (
            values[self.tau_indices] for values in full_schedule
        )
        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters dynamically from betas.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, returns parameters only for specified time steps.
            If None, returns parameters for all time steps.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,) or (len(time_steps),).
        alphas : torch.Tensor
            1 - betas, shape (num_steps,) or (len(time_steps),).
        alpha_cumprod : torch.Tensor
            Cumulative product of alphas, shape (num_steps,) or (len(time_steps),).
        sqrt_alpha_cumprod : torch.Tensor
            Square root of alpha_cumprod, shape (num_steps,) or (len(time_steps),).
        sqrt_one_minus_alpha_cumprod : torch.Tensor
            Square root of (1 - alpha_cumprod), shape (num_steps,) or (len(time_steps),).
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)
        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
|
|
204
|
+
|
|
205
|
+
class ForwardUnCLIP(nn.Module):
    """Forward diffusion process for UnCLIP diffusion models.

    Mixes clean data with Gaussian noise at the requested time steps using the
    square-root cumulative schedule values supplied by the variance scheduler.
    Inputs of any rank with a leading batch dimension are supported (2D
    embeddings and 4D images behave exactly as before).

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP) containing the noise
        schedule parameters.
    """
    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Parameters
        ----------
        `x0` : torch.Tensor
            Input data tensor, shape (batch_size, ...), e.g. (batch_size, embedding_dim)
            or (batch_size, channels, height, width).
        `noise` : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        `time_steps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,),
            where each value is in the range [0, variance_scheduler.num_steps - 1].

        Returns
        -------
        xt : torch.Tensor
            Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If any entry of `time_steps` lies outside the valid range.
        """
        scheduler = self.variance_scheduler
        if not torch.all((time_steps >= 0) & (time_steps < scheduler.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {scheduler.num_steps - 1}")

        if scheduler.trainable_beta:
            # Trainable betas have no cached buffers; derive the values on the fly.
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = scheduler.compute_schedule(time_steps)
        else:
            sqrt_alpha_cumprod_t = scheduler.sqrt_alpha_cumprod[time_steps]
            sqrt_one_minus_alpha_cumprod_t = scheduler.sqrt_one_minus_alpha_cumprod[time_steps]

        # One coefficient per sample, broadcast over every non-batch dimension:
        # (-1, 1) for 2D input, (-1, 1, 1, 1) for 4D input, and so on for other ranks.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.to(x0.device).view(coeff_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.to(x0.device).view(coeff_shape)

        return sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class ReverseUnCLIP(nn.Module):
    """Reverse diffusion process for UnCLIP diffusion models.

    Moves a noisy input `xt` from one step of the subsampled (tau) schedule to an
    earlier one, given a model prediction that is either the noise component or
    the clean sample. Inputs of any rank with a leading batch dimension are
    supported (2D embeddings and 4D images behave exactly as before).

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP) containing the noise
        schedule parameters.
    `prediction_type` : str, default "noise"
        Type of prediction the model makes. Either "noise" (predicts noise like DDIM) or
        "x0" (predicts clean image like UnCLIP prior).

    Raises
    ------
    ValueError
        If `prediction_type` is neither "noise" nor "x0".
    """

    _PREDICTION_TYPES = ("noise", "x0")

    def __init__(self, variance_scheduler: torch.nn.Module, prediction_type: str = "noise"):
        super().__init__()
        self.variance_scheduler = variance_scheduler
        self._check_prediction_type(prediction_type)
        self.prediction_type = prediction_type

    @classmethod
    def _check_prediction_type(cls, prediction_type: str) -> None:
        """Raise ValueError unless `prediction_type` is a supported value."""
        if prediction_type not in cls._PREDICTION_TYPES:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")

    def forward(self, xt: torch.Tensor, model_prediction: torch.Tensor, time_steps: torch.Tensor,
                prev_time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies the reverse diffusion process to the noisy input.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, ...), e.g.
            (batch_size, embedding_dim) or (batch_size, channels, height, width).
        `model_prediction` : torch.Tensor
            Model prediction tensor, same shape as `xt`. Can be either predicted noise
            or predicted clean image depending on `prediction_type`.
        `time_steps` : torch.Tensor
            Tensor of indices into the tau schedule (long), shape (batch_size,), where
            each value is in the range [0, variance_scheduler.tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Tensor of previous tau-schedule indices (long), shape (batch_size,), where
            each value is in the range [0, variance_scheduler.tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Denoised tensor at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If `time_steps` or `prev_time_steps` contains an out-of-range index.
        """
        scheduler = self.variance_scheduler
        if not torch.all((time_steps >= 0) & (time_steps < scheduler.tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {scheduler.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < scheduler.tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {scheduler.tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = scheduler.get_tau_schedule()

        # One coefficient per sample, broadcast over every non-batch dimension:
        # (-1, 1) for 2D input, (-1, 1, 1, 1) for 4D input, and so on for other ranks.
        coeff_shape = (-1,) + (1,) * (xt.dim() - 1)

        def per_sample(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
            return values[indices].to(xt.device).view(coeff_shape)

        tau_sqrt_alpha_cumprod_t = per_sample(tau_sqrt_alpha_cumprod, time_steps)
        tau_sqrt_one_minus_alpha_cumprod_t = per_sample(tau_sqrt_one_minus_alpha_cumprod, time_steps)
        prev_tau_sqrt_alpha_cumprod_t = per_sample(tau_sqrt_alpha_cumprod, prev_time_steps)
        prev_tau_sqrt_one_minus_alpha_cumprod_t = per_sample(tau_sqrt_one_minus_alpha_cumprod, prev_time_steps)

        eta = scheduler.eta

        if self.prediction_type == "noise":
            # Model predicts the noise; invert the forward process to estimate x0.
            predicted_noise = model_prediction
            x0 = (xt - tau_sqrt_one_minus_alpha_cumprod_t * predicted_noise) / tau_sqrt_alpha_cumprod_t
        else:
            # prediction_type == "x0": model predicts the clean sample; recover the
            # noise it implies for the current xt.
            x0 = model_prediction
            predicted_noise = (xt - tau_sqrt_alpha_cumprod_t * x0) / tau_sqrt_one_minus_alpha_cumprod_t

        # DDIM-style sampling step (same for both prediction types).
        # NOTE(review): apart from the clamp this expression reduces algebraically to
        # eta * sqrt(1 - a_prev) / sqrt(a_prev), which is not the sigma given in the
        # DDIM paper; behaviour is kept as-is (it is irrelevant for eta == 0) - confirm intent.
        noise_coeff = eta * ((tau_sqrt_one_minus_alpha_cumprod_t / prev_tau_sqrt_alpha_cumprod_t) *
                             prev_tau_sqrt_one_minus_alpha_cumprod_t / torch.clamp(tau_sqrt_one_minus_alpha_cumprod_t,
                                                                                   min=1e-8))
        # Clamp before the square root so a large noise_coeff cannot produce NaN.
        direction_coeff = torch.clamp(prev_tau_sqrt_one_minus_alpha_cumprod_t ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = prev_tau_sqrt_alpha_cumprod_t * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0

    def set_prediction_type(self, prediction_type: str):
        """Change the prediction type after initialization.

        Parameters
        ----------
        prediction_type : str
            Type of prediction the model makes. Either "noise" or "x0".

        Raises
        ------
        ValueError
            If `prediction_type` is neither "noise" nor "x0".
        """
        self._check_prediction_type(prediction_type)
        self.prediction_type = prediction_type
|
|
385
|
+
|
|
386
|
+
"""
|
|
387
|
+
hyp = VarianceSchedulerUnCLIP(
|
|
388
|
+
num_steps=1000,
|
|
389
|
+
beta_start=1e-4,
|
|
390
|
+
beta_end=0.02,
|
|
391
|
+
trainable_beta=False,
|
|
392
|
+
beta_method="sigmoid"
|
|
393
|
+
)
|
|
394
|
+
|
|
395
|
+
forward = ForwardUnCLIP(hyp)
|
|
396
|
+
x = torch.randn((10, 3, 100, 100))
|
|
397
|
+
t = torch.randint(0, 1000, (10,))
|
|
398
|
+
noise = torch.randn_like(x)
|
|
399
|
+
|
|
400
|
+
xt = forward(x, noise, t)
|
|
401
|
+
print(xt.size())
|
|
402
|
+
"""
|
unclip/prior_model.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
from typing import Union, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class UnCLIPTransformerPrior(nn.Module):
    """Transformer-based prior model for UnCLIP diffusion.

    Predicts clean image embeddings from noisy image embeddings and text embeddings.
    The time step is encoded sinusoidally, passed through a small MLP and added to
    the noisy image embedding; the text and image embeddings then form a two-token
    sequence that is processed by a stack of Transformer blocks.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
        Stored on the model; not called by `forward`.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising during training.
        Stored on the model; not called by `forward`.
    `clip_text_projection` : nn.Module, optional
        Projection module for text embeddings, default None. Stored only.
    `clip_image_projection` : nn.Module, optional
        Projection module for image embeddings, default None. Stored only.
    `transformer_embedding_dim` : int, optional
        Dimensionality of embeddings (default: 320).
    `num_layers` : int, optional
        Number of Transformer layers (default: 12).
    `num_attention_heads` : int, optional
        Number of attention heads in each Transformer layer (default: 8).
    `feedforward_dim` : int, optional
        Dimensionality of the feedforward network in Transformer layers (default: 768).
    `max_sequence_length` : int, optional
        Number of rows in the learned positional table (default: 2).
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.2).
    """
    def __init__(
        self,
        forward_diffusion: nn.Module,  # will be used during training
        reverse_diffusion: nn.Module,  # will be used during training
        clip_text_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        clip_image_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        transformer_embedding_dim: int = 320,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        feedforward_dim: int = 768,
        max_sequence_length: int = 2,
        dropout_rate: float = 0.2
    ) -> None:
        super().__init__()

        self.forward_diffusion = forward_diffusion
        self.reverse_diffusion = reverse_diffusion
        self.clip_text_projection = clip_text_projection
        self.clip_image_projection = clip_image_projection

        self.transformer_embedding_dim = transformer_embedding_dim
        self.max_sequence_length = max_sequence_length

        # Time embedding network
        self.time_embedding_net = nn.Sequential(
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim),
            nn.GELU(),
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim)
        )

        # Positional embeddings
        self.positional_embeddings = nn.Parameter(torch.randn(max_sequence_length, transformer_embedding_dim))

        # Transformer layers
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(transformer_embedding_dim, num_attention_heads, feedforward_dim, dropout_rate)
            for _ in range(num_layers)
        ])

        # Final output projection
        self.output_projection = nn.Linear(transformer_embedding_dim, transformer_embedding_dim)

    def forward(
        self,
        text_embeddings: torch.Tensor,
        noisy_image_embeddings: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Predicts clean image embeddings from noisy inputs and text embeddings.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            Text embeddings, shape (batch_size, embedding_dim).
        `noisy_image_embeddings` : torch.Tensor
            Noisy image embeddings, shape (batch_size, embedding_dim).
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).

        Returns
        -------
        predicted_clean_embeddings : torch.Tensor
            Predicted clean image embeddings, shape (batch_size, embedding_dim).
        """
        device = text_embeddings.device

        # Sinusoidal time embeddings. The width comes from `transformer_embedding_dim`;
        # the model defines no `embedding_dim` attribute, so reading that name raised
        # AttributeError on every call.
        time_embeddings = self._get_sinusoidal_embeddings(timesteps, self.transformer_embedding_dim, device)
        time_embeddings = self.time_embedding_net(time_embeddings)

        # Add time information to image embeddings
        conditioned_image_embeddings = noisy_image_embeddings + time_embeddings

        # Create sequence: [text_embeddings, conditioned_image_embeddings]
        sequence = torch.stack([text_embeddings, conditioned_image_embeddings], dim=1)  # [B, 2, D]

        # Add positional embeddings. Only the rows matching the sequence length are
        # used, so a table built with max_sequence_length > 2 still broadcasts.
        sequence = sequence + self.positional_embeddings[:sequence.shape[1]].unsqueeze(0)

        # Pass through transformer blocks
        for transformer_block in self.transformer_blocks:
            sequence = transformer_block(sequence)

        # The image token sits at the second position of the sequence.
        predicted_clean_embeddings = sequence[:, 1, :]  # [B, D]

        return self.output_projection(predicted_clean_embeddings)

    def _get_sinusoidal_embeddings(
        self,
        timesteps: torch.Tensor,
        embedding_dim: int,
        device: Union[torch.device, str]
    ) -> torch.Tensor:
        """Generates sinusoidal positional embeddings for timesteps.

        The first half of each embedding holds sines and the second half cosines of
        the time step scaled by geometrically spaced frequencies; odd widths are
        padded with one trailing zero column.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).
        `embedding_dim` : int
            Dimensionality of the embeddings.
        `device` : Union[torch.device, str]
            Device to place the embeddings on.

        Returns
        -------
        embeddings : torch.Tensor
            Sinusoidal time embeddings, shape (batch_size, embedding_dim).
        """
        half_dim = embedding_dim // 2
        # max(..., 1) avoids a division by zero when embedding_dim is 2 or 3
        # (half_dim == 1); larger widths are unaffected.
        scale = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = timesteps[:, None].float() * frequencies[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

        # Handle odd embedding dimensions: pad one zero column. An explicit
        # (batch, 1) tensor is used so the padding exists even when half_dim == 0.
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)

        return emb
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block used by UnCLIPTransformerPrior.

    Runs multi-head self-attention and a two-layer GELU feedforward network, each
    wrapped in a residual connection that is followed by layer normalization.

    Parameters
    ----------
    `embedding_dim` : int
        Dimensionality of input and output embeddings.
    `num_heads` : int
        Number of attention heads in the multi-head attention layer.
    `feedforward_dim` : int
        Hidden width of the feedforward network.
    `dropout` : float
        Dropout probability used inside attention and the feedforward network.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        feedforward_dim: int,
        dropout: float
    ) -> None:
        super().__init__()

        # Submodule names and construction order are kept stable so that saved
        # state dicts and seeded initialisation remain unchanged.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attention_norm = nn.LayerNorm(embedding_dim)
        self.feedforward_norm = nn.LayerNorm(embedding_dim)

        mlp_layers = [
            nn.Linear(embedding_dim, feedforward_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.feedforward = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies attention and feedforward sub-layers to a sequence.

        Parameters
        ----------
        `x` : torch.Tensor
            Input sequence tensor, shape (batch_size, sequence_length, embedding_dim).

        Returns
        -------
        output : torch.Tensor
            Processed sequence tensor, shape (batch_size, sequence_length, embedding_dim).
        """
        # Attention sub-layer: queries, keys and values all come from x; the
        # attention weights returned alongside the output are discarded.
        attended = self.self_attention(x, x, x)[0]
        hidden = self.attention_norm(x + attended)

        # Feedforward sub-layer with its own residual connection and normalization.
        return self.feedforward_norm(hidden + self.feedforward(hidden))
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
"""
|
|
247
|
+
model = UnCLIPTransformerPrior(
|
|
248
|
+
embedding_dim=320,
|
|
249
|
+
num_layers=12,
|
|
250
|
+
num_attention_heads=8,
|
|
251
|
+
feedforward_dim=768,
|
|
252
|
+
max_sequence_length=2,
|
|
253
|
+
dropout_rate=0.3
|
|
254
|
+
)
|
|
255
|
+
|
|
256
|
+
x = torch.randn((10, 320))
|
|
257
|
+
t = torch.randint(0, 1000, (10,))
|
|
258
|
+
print(t.size())
|
|
259
|
+
tm = model._get_sinusoidal_embeddings(t, 320, "cpu")
|
|
260
|
+
print(tm.size())
|
|
261
|
+
y = torch.randn((10, 320))
|
|
262
|
+
p = model(y, x, t)
|
|
263
|
+
print(p.size())
|
|
264
|
+
"""
|