TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/prior_diff.py ADDED
@@ -0,0 +1,402 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from typing import Optional, Tuple
5
+
6
class VarianceSchedulerUnCLIP(nn.Module):
    """Manages noise schedule parameters for UnCLIP diffusion models.

    Handles beta values, derived noise schedule quantities, and a subsampled time step schedule
    (tau schedule) for UnCLIP diffusion processes. Supports trainable or fixed beta schedules
    and multiple scheduling methods, including linear, sigmoid, quadratic, constant, inverse_time,
    and cosine schedules. Every schedule (cosine included) is clamped into [beta_start, beta_end].

    Parameters
    ----------
    `eta` : float, optional
        Noise scaling factor for the reverse process (default: None, treated as 0 / deterministic).
    `num_steps` : int, optional
        Total number of diffusion steps (default: 1000).
    `tau_num_steps` : int, optional
        Number of subsampled time steps for sampling (default: 100). Must be positive.
    `beta_start` : float, optional
        Starting value for beta (default: 1e-4).
    `beta_end` : float, optional
        Ending value for beta (default: 0.02).
    `trainable_beta` : bool, optional
        Whether the beta schedule is trainable (default: False).
    `beta_method` : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

    Raises
    ------
    ValueError
        If the beta range is invalid, `num_steps` or `tau_num_steps` is not positive,
        or `beta_method` is unknown.
    """

    # Clamp bound used when inverting the sigmoid reparameterization of trainable betas.
    _LOGIT_EPS = 1e-6

    def __init__(
        self,
        eta: Optional[float] = None,
        num_steps: int = 1000,
        tau_num_steps: int = 100,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        trainable_beta: bool = False,
        beta_method: str = "linear"
    ) -> None:
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)

        if trainable_beta:
            # Invert betas = start + (end - start) * sigmoid(beta_raw). The normalized
            # schedule typically hits exactly 0 and 1 at its ends, where logit is -inf/+inf;
            # an infinite parameter cannot be updated meaningfully, so clamp via `eps`.
            normalized = (betas_init - beta_start) / (beta_end - beta_start)
            self.beta_raw = nn.Parameter(torch.logit(normalized, eps=self._LOGIT_EPS))
        else:
            # Fixed schedule: precompute all derived quantities once as buffers.
            self.register_buffer('betas_buffer', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced indices into the full schedule used for subsampled (tau) sampling.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    @property
    def betas(self) -> torch.Tensor:
        """Returns the beta values, applying reparameterization if trainable.

        Returns
        -------
        betas : torch.Tensor
            Beta values, shape (num_steps,). For trainable betas this is
            `beta_start + (beta_end - beta_start) * sigmoid(beta_raw)`; otherwise
            the stored `betas_buffer`.
        """
        if self.trainable_beta:
            return self.beta_start + (self.beta_end - self.beta_start) * torch.sigmoid(self.beta_raw)
        return self._buffers['betas_buffer']

    def compute_beta_schedule(self, beta_range: Tuple[float, float], num_steps: int, method: str) -> torch.Tensor:
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        `beta_range` : tuple
            Tuple of (min_beta, max_beta); the result is clamped into this range.
        `num_steps` : int
            Number of diffusion steps.
        `method` : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time", "cosine".

        Returns
        -------
        beta : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            inverse = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = inverse.max() - inverse.min()
            # With a single step the span is zero; avoid 0/0 (NaN) and start at beta_min,
            # matching the first value of the multi-step schedule.
            if span > 0:
                fraction = (inverse - inverse.min()) / span
            else:
                fraction = torch.zeros_like(inverse)
            beta = beta_min + (beta_max - beta_min) * fraction
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        elif method == "cosine":
            s = 0.008
            steps = num_steps + 1
            x = torch.linspace(0, num_steps, steps)
            alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            beta = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        else:
            raise ValueError(f"Unknown beta_method: {method}")
        # Note: this clamp also applies to the cosine schedule.
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes the subsampled (tau) noise schedule for UnCLIP.

        Returns
        -------
        tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod
            Each a tensor of shape (tau_num_steps,), gathered at `tau_indices`.
        """
        if self.trainable_beta:
            # Recompute from the current parameters so gradients flow through betas.
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule()
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, time_steps: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Computes noise schedule parameters dynamically from the current betas.

        Parameters
        ----------
        `time_steps` : torch.Tensor, optional
            If provided, returns parameters only at these indices; otherwise for all steps.

        Returns
        -------
        betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
            Each of shape (num_steps,) or the indexed shape of `time_steps`.
        """
        betas = self.betas
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alpha_cumprod)
        if time_steps is not None:
            return (betas[time_steps], alphas[time_steps], alpha_cumprod[time_steps],
                    sqrt_alpha_cumprod[time_steps], sqrt_one_minus_alpha_cumprod[time_steps])
        return betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod
204
+
205
class ForwardUnCLIP(nn.Module):
    """Forward diffusion process for UnCLIP diffusion models.

    Applies Gaussian noise to input data according to the forward diffusion process at
    specified time steps: `xt = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise`.
    Inputs of any rank with a leading batch dimension are supported (e.g. 2D embeddings
    or 4D images); the per-sample coefficients are broadcast over all trailing dimensions.

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP). It must expose `num_steps`,
        `trainable_beta`, and either `compute_schedule(time_steps)` (trainable) or the
        `sqrt_alpha_cumprod` / `sqrt_one_minus_alpha_cumprod` tensors (fixed).
    """
    def __init__(self, variance_scheduler: torch.nn.Module) -> None:
        super().__init__()
        self.variance_scheduler = variance_scheduler

    def forward(self, x0: torch.Tensor, noise: torch.Tensor, time_steps: torch.Tensor) -> torch.Tensor:
        """Applies the forward diffusion process to the input data.

        Parameters
        ----------
        `x0` : torch.Tensor
            Clean input, shape (batch_size, ...), e.g. (batch_size, embedding_dim) or
            (batch_size, channels, height, width).
        `noise` : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        `time_steps` : torch.Tensor
            Time step indices (long), shape (batch_size,), each in [0, num_steps - 1].

        Returns
        -------
        xt : torch.Tensor
            Noisy data at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If any time step is outside [0, num_steps - 1].
        """
        num_steps = self.variance_scheduler.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.variance_scheduler.trainable_beta:
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.compute_schedule(
                time_steps
            )
        else:
            sqrt_alpha_cumprod_t = self.variance_scheduler.sqrt_alpha_cumprod[time_steps]
            sqrt_one_minus_alpha_cumprod_t = self.variance_scheduler.sqrt_one_minus_alpha_cumprod[time_steps]

        # Reshape per-sample coefficients to (batch_size, 1, ..., 1) so they broadcast over
        # every trailing dimension of x0 regardless of its rank.
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.to(x0.device).reshape(broadcast_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.to(x0.device).reshape(broadcast_shape)

        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
        return xt
270
+
271
+
272
class ReverseUnCLIP(nn.Module):
    """Reverse diffusion process for UnCLIP diffusion models.

    Performs one DDIM-style denoising step on a noisy input `xt` using either a predicted
    noise component or a predicted clean sample, on the subsampled (tau) time step schedule.
    Inputs of any rank with a leading batch dimension are supported (e.g. 2D embeddings or
    4D images).

    Parameters
    ----------
    `variance_scheduler` : torch.nn.Module
        Variance scheduler module (e.g., VarianceSchedulerUnCLIP). It must expose `eta`,
        `tau_num_steps`, and `get_tau_schedule()` returning a 5-tuple whose last two entries
        are sqrt(alpha_cumprod) and sqrt(1 - alpha_cumprod) over the tau steps.
    `prediction_type` : str, default "noise"
        Type of prediction the model makes. Either "noise" (predicts noise like DDIM) or
        "x0" (predicts the clean sample like the UnCLIP prior).
    """

    _VALID_PREDICTION_TYPES = ("noise", "x0")

    def __init__(self, variance_scheduler: torch.nn.Module, prediction_type: str = "noise"):
        super().__init__()
        self.variance_scheduler = variance_scheduler
        if prediction_type not in self._VALID_PREDICTION_TYPES:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type

    def forward(self, xt: torch.Tensor, model_prediction: torch.Tensor, time_steps: torch.Tensor,
                prev_time_steps: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies one reverse diffusion step to the noisy input.

        Parameters
        ----------
        `xt` : torch.Tensor
            Noisy input at tau step `t`, shape (batch_size, ...).
        `model_prediction` : torch.Tensor
            Model output, same shape as `xt`: predicted noise or predicted clean sample,
            depending on `prediction_type`.
        `time_steps` : torch.Tensor
            Tau step indices (long), shape (batch_size,), each in [0, tau_num_steps - 1].
        `prev_time_steps` : torch.Tensor
            Previous tau step indices (long), shape (batch_size,), each in [0, tau_num_steps - 1].

        Returns
        -------
        xt_prev : torch.Tensor
            Sample at `prev_time_steps`, same shape as `xt`.
        x0 : torch.Tensor
            Estimated clean sample, same shape as `xt`.

        Raises
        ------
        ValueError
            If a step index is out of range or `prediction_type` is invalid.
        """
        tau_num_steps = self.variance_scheduler.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.variance_scheduler.get_tau_schedule()

        # Per-sample coefficients reshaped to (batch_size, 1, ..., 1) to broadcast over xt.
        broadcast_shape = (-1,) + (1,) * (xt.dim() - 1)

        def gather(schedule: torch.Tensor, steps: torch.Tensor) -> torch.Tensor:
            return schedule[steps].to(xt.device).reshape(broadcast_shape)

        sqrt_ac_t = gather(tau_sqrt_alpha_cumprod, time_steps)
        sqrt_one_minus_ac_t = gather(tau_sqrt_one_minus_alpha_cumprod, time_steps)
        sqrt_ac_prev = gather(tau_sqrt_alpha_cumprod, prev_time_steps)
        sqrt_one_minus_ac_prev = gather(tau_sqrt_one_minus_alpha_cumprod, prev_time_steps)

        eta = self.variance_scheduler.eta

        if self.prediction_type == "noise":
            predicted_noise = model_prediction
            x0 = (xt - sqrt_one_minus_ac_t * predicted_noise) / sqrt_ac_t
        elif self.prediction_type == "x0":
            x0 = model_prediction
            # Noise implied by the predicted clean sample.
            predicted_noise = (xt - sqrt_ac_t * x0) / sqrt_one_minus_ac_t
        else:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {self.prediction_type}")

        # DDIM step (Song et al., eq. 16):
        #   sigma = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
        # where a_* are cumulative alpha products at the respective tau steps.
        alpha_cumprod_t = sqrt_ac_t ** 2
        alpha_cumprod_prev = sqrt_ac_prev ** 2
        variance_ratio = sqrt_one_minus_ac_prev / torch.clamp(sqrt_one_minus_ac_t, min=1e-8)
        step_ratio = torch.clamp(1 - alpha_cumprod_t / torch.clamp(alpha_cumprod_prev, min=1e-8), min=0.0)
        noise_coeff = eta * variance_ratio * torch.sqrt(step_ratio)
        direction_coeff = torch.clamp(sqrt_one_minus_ac_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        # randn_like is always drawn (even when eta == 0) so the RNG stream stays consistent.
        xt_prev = sqrt_ac_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0

    def set_prediction_type(self, prediction_type: str):
        """Change the prediction type after initialization.

        Parameters
        ----------
        prediction_type : str
            Type of prediction the model makes. Either "noise" or "x0".

        Raises
        ------
        ValueError
            If `prediction_type` is not "noise" or "x0".
        """
        if prediction_type not in self._VALID_PREDICTION_TYPES:
            raise ValueError(f"prediction_type must be either 'noise' or 'x0', got {prediction_type}")
        self.prediction_type = prediction_type
385
+
386
if __name__ == "__main__":
    # Demo: add noise to a random image batch using a fixed sigmoid beta schedule.
    # Previously this lived in a no-op string literal; it now runs only when the
    # module is executed directly, never on import.
    hyp = VarianceSchedulerUnCLIP(
        num_steps=1000,
        beta_start=1e-4,
        beta_end=0.02,
        trainable_beta=False,
        beta_method="sigmoid",
    )

    forward = ForwardUnCLIP(hyp)
    x = torch.randn((10, 3, 100, 100))
    t = torch.randint(0, 1000, (10,))
    noise = torch.randn_like(x)

    xt = forward(x, noise, t)
    print(xt.size())
unclip/prior_model.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from typing import Union, Optional
5
+
6
+
7
class UnCLIPTransformerPrior(nn.Module):
    """Transformer-based prior model for UnCLIP diffusion.

    Predicts clean image embeddings from noisy image embeddings and text embeddings using
    a Transformer over the two-token sequence [text, noisy_image + time_embedding].
    The forward/reverse diffusion modules and the optional projections are stored on the
    model but are not used by `forward` itself.

    Parameters
    ----------
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising during training.
    `clip_text_projection` : nn.Module, optional
        Projection module for text embeddings, default None.
    `clip_image_projection` : nn.Module, optional
        Projection module for image embeddings, default None.
    `transformer_embedding_dim` : int, optional
        Dimensionality of embeddings (default: 320).
    `num_layers` : int, optional
        Number of Transformer layers (default: 12).
    `num_attention_heads` : int, optional
        Number of attention heads in each Transformer layer (default: 8).
    `feedforward_dim` : int, optional
        Dimensionality of the feedforward network in Transformer layers (default: 768).
    `max_sequence_length` : int, optional
        Number of learned positional embeddings (default: 2). `forward` uses the first
        two, so this must be at least 2.
    `dropout_rate` : float, optional
        Dropout probability for regularization (default: 0.2).
    """
    def __init__(
        self,
        forward_diffusion: nn.Module,  # will be used during training
        reverse_diffusion: nn.Module,  # will be used during training
        clip_text_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        clip_image_projection: Optional[nn.Module] = None,  # used during training instead of PCA in the main paper
        transformer_embedding_dim: int = 320,
        num_layers: int = 12,
        num_attention_heads: int = 8,
        feedforward_dim: int = 768,
        max_sequence_length: int = 2,
        dropout_rate: float = 0.2
    ) -> None:
        super().__init__()

        self.forward_diffusion = forward_diffusion
        self.reverse_diffusion = reverse_diffusion
        self.clip_text_projection = clip_text_projection
        self.clip_image_projection = clip_image_projection

        self.transformer_embedding_dim = transformer_embedding_dim
        self.max_sequence_length = max_sequence_length

        # Time embedding network applied on top of sinusoidal features.
        self.time_embedding_net = nn.Sequential(
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim),
            nn.GELU(),
            nn.Linear(transformer_embedding_dim, transformer_embedding_dim)
        )

        # Learned positional embeddings, one row per sequence position.
        self.positional_embeddings = nn.Parameter(torch.randn(max_sequence_length, transformer_embedding_dim))

        # Transformer layers
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(transformer_embedding_dim, num_attention_heads, feedforward_dim, dropout_rate)
            for _ in range(num_layers)
        ])

        # Final output projection
        self.output_projection = nn.Linear(transformer_embedding_dim, transformer_embedding_dim)

    def forward(
        self,
        text_embeddings: torch.Tensor,
        noisy_image_embeddings: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Predicts clean image embeddings from noisy inputs and text embeddings.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            Text embeddings, shape (batch_size, transformer_embedding_dim).
        `noisy_image_embeddings` : torch.Tensor
            Noisy image embeddings, shape (batch_size, transformer_embedding_dim).
        `timesteps` : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,).

        Returns
        -------
        predicted_clean_embeddings : torch.Tensor
            Predicted clean image embeddings, shape (batch_size, transformer_embedding_dim).
        """
        device = text_embeddings.device

        # Sinusoidal time features, cast to the time network's parameter dtype so the
        # model also works after `.half()` / `.to(dtype)`.
        time_embeddings = self._get_sinusoidal_embeddings(timesteps, self.transformer_embedding_dim, device)
        time_embeddings = time_embeddings.to(dtype=self.time_embedding_net[0].weight.dtype)
        time_embeddings = self.time_embedding_net(time_embeddings)

        # Add time information to image embeddings
        conditioned_image_embeddings = noisy_image_embeddings + time_embeddings

        # Sequence: [text_embeddings, conditioned_image_embeddings] -> [B, 2, D]
        sequence = torch.stack([text_embeddings, conditioned_image_embeddings], dim=1)

        # Use only as many positional rows as the sequence has positions, so any
        # max_sequence_length >= 2 broadcasts correctly.
        sequence = sequence + self.positional_embeddings[:sequence.size(1)].unsqueeze(0)

        for transformer_block in self.transformer_blocks:
            sequence = transformer_block(sequence)

        # The image token (position 1) carries the prediction.
        predicted_clean_embeddings = sequence[:, 1, :]  # [B, D]
        predicted_clean_embeddings = self.output_projection(predicted_clean_embeddings)

        return predicted_clean_embeddings

    def _get_sinusoidal_embeddings(
        self,
        timesteps: torch.Tensor,
        embedding_dim: int,
        device: Union[torch.device, str]
    ) -> torch.Tensor:
        """Generates sinusoidal embeddings for timesteps.

        Parameters
        ----------
        `timesteps` : torch.Tensor
            Tensor of time step indices, shape (batch_size,).
        `embedding_dim` : int
            Dimensionality of the embeddings.
        `device` : Union[torch.device, str]
            Device to place the embeddings on.

        Returns
        -------
        embeddings : torch.Tensor
            Float tensor of shape (batch_size, embedding_dim): sines in the first half,
            cosines in the second, plus one zero column when `embedding_dim` is odd.
        """
        half_dim = embedding_dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (embedding_dim 2 or 3).
        scale = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = timesteps.to(device)[:, None].float() * frequencies[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

        # Handle odd embedding dimensions
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)

        return emb
172
+
173
+
174
class TransformerBlock(nn.Module):
    """Post-norm Transformer encoder block used by ``UnCLIPTransformerPrior``.

    Each call runs multi-head self-attention and then a position-wise feedforward
    network. Both sub-layers use a residual connection followed by ``LayerNorm``
    (normalization is applied after the residual sum).

    Parameters
    ----------
    `embedding_dim` : int
        Width of the token embeddings entering and leaving the block.
    `num_heads` : int
        Number of attention heads passed to ``nn.MultiheadAttention``.
    `feedforward_dim` : int
        Hidden width of the feedforward sub-layer.
    `dropout` : float
        Dropout probability used inside attention and the feedforward sub-layer.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        feedforward_dim: int,
        dropout: float
    ) -> None:
        super().__init__()

        # Submodules are created in a fixed order: attention, both norms, then the
        # feedforward stack. This order fixes parameter registration and the
        # sequence of random initializations, so keep it stable.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.attention_norm = nn.LayerNorm(embedding_dim)
        self.feedforward_norm = nn.LayerNorm(embedding_dim)

        ff_layers = [
            nn.Linear(embedding_dim, feedforward_dim),  # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, embedding_dim),  # project back to model width
            nn.Dropout(dropout),
        ]
        self.feedforward = nn.Sequential(*ff_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one attention + feedforward pass over a batch of sequences.

        Parameters
        ----------
        `x` : torch.Tensor
            Input of shape (batch_size, sequence_length, embedding_dim); the attention
            layer is constructed with ``batch_first=True``.

        Returns
        -------
        output : torch.Tensor
            Tensor with the same shape as `x`.
        """
        attended, _attn_weights = self.self_attention(x, x, x)
        hidden = self.attention_norm(x + attended)
        return self.feedforward_norm(hidden + self.feedforward(hidden))
244
+
245
+
246
if __name__ == "__main__":
    # Demo: run the prior on random embeddings. The old commented-out example passed a
    # non-existent `embedding_dim` keyword and omitted the required diffusion modules.
    # `forward` never touches those modules, so identity placeholders suffice here.
    model = UnCLIPTransformerPrior(
        forward_diffusion=nn.Identity(),
        reverse_diffusion=nn.Identity(),
        transformer_embedding_dim=320,
        num_layers=12,
        num_attention_heads=8,
        feedforward_dim=768,
        max_sequence_length=2,
        dropout_rate=0.3,
    )

    x = torch.randn((10, 320))
    t = torch.randint(0, 1000, (10,))
    print(t.size())
    tm = model._get_sinusoidal_embeddings(t, 320, "cpu")
    print(tm.size())
    y = torch.randn((10, 320))
    p = model(y, x, t)
    print(p.size())