torchdiff-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/unclip_sampler.py
ADDED
@@ -0,0 +1,626 @@
import torch
import torch.nn as nn
import torchvision
from typing import Optional, Union, List, Tuple, Self
import os


class SampleUnCLIP(nn.Module):
    """Generates images using the UnCLIP model pipeline.

    Combines a prior model, decoder model, CLIP model, and upsampler models to generate
    images from text prompts or noise. Performs diffusion-based sampling with classifier-free
    guidance in both prior and decoder stages, followed by upsampling to higher resolutions.

    Parameters
    ----------
    `prior_model` : nn.Module
        The UnCLIP prior model for generating image embeddings from text.
    `decoder_model` : nn.Module
        The UnCLIP decoder model for generating low-resolution images from embeddings.
    `clip_model` : nn.Module
        CLIP model for encoding text prompts into embeddings.
    `low_res_upsampler` : nn.Module
        First upsampler model for scaling images from 64x64 to 256x256.
    `high_res_upsampler` : nn.Module, optional
        Second upsampler model for scaling images from 256x256 to 1024x1024, default None.
    `device` : Union[torch.device, str], optional
        Device for computation (default: CUDA if available, else CPU).
    `clip_embedding_dim` : int, optional
        Dimensionality of CLIP embeddings (default: 512).
    `prior_guidance_scale` : float, optional
        Classifier-free guidance scale for the prior model (default: 4.0).
    `decoder_guidance_scale` : float, optional
        Classifier-free guidance scale for the decoder model (default: 8.0).
    `batch_size` : int, optional
        Number of images to generate per batch (default: 1).
    `normalize_clip_embeddings` : bool, optional
        Whether to normalize CLIP embeddings (default: True).
    `prior_dim_reduction` : bool, optional
        Whether to apply dimensionality reduction in the prior model (default: True).
    `initial_image_size` : Tuple[int, int, int], optional
        Size of the initial generated images (default: (3, 64, 64) for RGB 64x64).
    `use_high_res_upsampler` : bool, optional
        Whether to use the second upsampler for 1024x1024 output (default: True).
    `image_output_range` : Tuple[float, float], optional
        Value range of the raw generated images, used when rescaling outputs (default: (-1.0, 1.0)).
    """
    def __init__(
        self,
        prior_model: nn.Module,
        decoder_model: nn.Module,
        clip_model: nn.Module,
        low_res_upsampler: nn.Module,
        high_res_upsampler: Optional[nn.Module] = None,
        device: Optional[Union[torch.device, str]] = None,
        clip_embedding_dim: int = 512,  # CLIP embedding dimension
        prior_guidance_scale: float = 4.0,
        decoder_guidance_scale: float = 8.0,
        batch_size: int = 1,
        normalize_clip_embeddings: bool = True,
        prior_dim_reduction: bool = True,
        initial_image_size: Tuple[int, int, int] = (3, 64, 64),
        use_high_res_upsampler: bool = True,
        image_output_range: Tuple[float, float] = (-1.0, 1.0)
    ) -> None:
        super().__init__()

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        self.prior_model = prior_model.to(self.device)
        self.decoder_model = decoder_model.to(self.device)
        self.clip_model = clip_model.to(self.device)
        self.low_res_upsampler = low_res_upsampler.to(self.device)
        self.high_res_upsampler = high_res_upsampler.to(self.device) if high_res_upsampler else None

        self.prior_guidance_scale = prior_guidance_scale
        self.decoder_guidance_scale = decoder_guidance_scale
        self.batch_size = batch_size
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.prior_dim_reduction = prior_dim_reduction
        self.clip_embedding_dim = clip_embedding_dim
        self.initial_image_size = initial_image_size
        self.use_high_res_upsampler = use_high_res_upsampler
        self.image_output_range = image_output_range
        self.images_256 = None
        self.images_1024 = None

    def forward(
        self,
        prompts: Optional[Union[str, List]] = None,
        normalize_output: bool = True,
        save_images: bool = True,
        save_path: str = "unclip_generated"
    ):
        """Generates images from text prompts or noise using the UnCLIP pipeline.

        Executes the full UnCLIP generation process: the prior model generates image embeddings,
        the decoder model generates 64x64 images, the first upsampler scales to 256x256, and an
        optional second upsampler scales to 1024x1024. Supports classifier-free guidance and
        saves generated images if requested.

        Parameters
        ----------
        `prompts` : Union[str, List], optional
            Text prompt(s) for conditional generation, default None (unconditional).
        `normalize_output` : bool, optional
            Whether to normalize output images to [0, 1] range (default: True).
        `save_images` : bool, optional
            Whether to save generated images to disk (default: True).
        `save_path` : str, optional
            Directory to save generated images (default: "unclip_generated").

        Returns
        -------
        final_images : torch.Tensor
            Generated images, shape (batch_size, channels, height, width), either 256x256
            or 1024x1024 depending on `use_high_res_upsampler`.
        """
        # initialize noise for prior sampling (image embedding space)
        embedding_noise = torch.randn((self.batch_size, self.clip_embedding_dim), device=self.device)
        print("embedding noise: ", embedding_noise.size())

        with torch.no_grad():
            # ====== PRIOR STAGE: generate image embeddings from text ======
            print("############################################################")
            print(" prior model ")
            print("############################################################")
            # encode text prompt using CLIP
            text_embeddings = self.clip_model(data=prompts, data_type="text", normalize=self.normalize_clip_embeddings)
            print("text embedding : ", text_embeddings.size())

            current_embeddings = embedding_noise.clone()

            # optionally reduce dimensionality for prior model
            if self.prior_dim_reduction:
                text_embeddings_reduced = self.prior_model.text_projection(text_embeddings)
                current_embeddings_reduced = self.prior_model.image_projection(current_embeddings)
                print("text embedding reduced: ", text_embeddings_reduced.size())
                print("current embedding reduced: ", current_embeddings_reduced.size())
            else:
                text_embeddings_reduced = text_embeddings
                current_embeddings_reduced = current_embeddings
                print("text embedding reduced: ", text_embeddings_reduced.size())
                print("current embedding reduced: ", current_embeddings_reduced.size())

            # prior diffusion sampling loop
            t_counter = 0
            for t in reversed(range(self.prior_model.forward_diffusion.variance_scheduler.tau_num_steps)):
                timesteps = torch.full((self.batch_size,), t, device=self.device)
                prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)

                # predict embeddings
                predicted_embeddings = self.prior_model(text_embeddings_reduced, current_embeddings_reduced, timesteps)
                if t == 10:
                    print("predicted embeddings: ", predicted_embeddings.size())

                # apply guidance
                guided_embeddings = self.compute_prior_guided_prediction(
                    predicted_embeddings, text_embeddings_reduced, current_embeddings_reduced, timesteps
                )
                if t == 10:
                    print("guided embeddings: ", guided_embeddings.size())

                # update embeddings using reverse diffusion
                current_embeddings_reduced, _ = self.prior_model.reverse_diffusion(
                    current_embeddings_reduced, guided_embeddings, timesteps, prev_timesteps
                )
                if t == 10:
                    print("current embedding reduced: ", current_embeddings_reduced.size())

                # convert back to full embedding dimension if needed
                if self.prior_dim_reduction:
                    final_image_embeddings = self.prior_model.image_projection.inverse_transform(current_embeddings_reduced)
                    print("final image embeddings: ", final_image_embeddings.size())
                else:
                    final_image_embeddings = current_embeddings_reduced
                    print("final image embeddings: ", final_image_embeddings.size())

                t_counter += 1
            print("number of iters in prior model: ", t_counter)

            # ====== DECODER STAGE: generate 64x64 images from embeddings ======

            print("############################################################")
            print(" decoder model ")
            print("############################################################")

            # initialize noise for decoder sampling
            decoder_noise = torch.randn((self.batch_size, self.initial_image_size[0], self.initial_image_size[1], self.initial_image_size[2]), device=self.device)
            print("decoder noise: ", decoder_noise.size())

            # project image embeddings to 4 tokens
            projected_embeddings = self.decoder_model.decoder_projection(final_image_embeddings)
            print("projected embeddings: ", projected_embeddings.size())

            # encode text with GLIDE/decoder's text encoder
            glide_text_embeddings = self.decoder_model._encode_text_with_glide(prompts)
            print("glide text embeddings: ", glide_text_embeddings.size())

            # concatenate embeddings for context
            context = self.decoder_model._concatenate_embeddings(glide_text_embeddings, projected_embeddings)
            print("context: ", context.size())

            current_images = decoder_noise
            # decoder diffusion sampling loop
            t_counter = 0
            for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
                timesteps = torch.full((self.batch_size,), t, device=self.device)
                prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)

                # predict noise
                predicted_noise = self.decoder_model.noise_predictor(current_images, timesteps, context, None)
                if t == 10:
                    print("predicted noise: ", predicted_noise.size())

                # apply guidance
                guided_noise = self.compute_decoder_guided_prediction(
                    predicted_noise, current_images, timesteps, context
                )
                if t == 10:
                    print("guided noise: ", guided_noise.size())

                # update images using reverse diffusion
                current_images, _ = self.decoder_model.reverse_diffusion(
                    current_images, guided_noise, timesteps, prev_timesteps
                )
                if t == 10:
                    print("current image: ", current_images.size())
                t_counter += 1

            generated_64x64 = current_images
            print("number of iters in decoder model: ", t_counter)

            # ====== FIRST UPSAMPLER: 64x64 -> 256x256 ======
            print("############################################################")
            print(" first upsampler ")
            print("############################################################")
            upsampled_256_noise = torch.randn((self.batch_size, self.initial_image_size[0], 256, 256), device=self.device)
            current_256_images = upsampled_256_noise
            print("upsampled 256 noise: ", upsampled_256_noise.size())

            t_counter = 0
            for t in reversed(range(self.low_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
                timesteps = torch.full((self.batch_size,), t, device=self.device)
                prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)

                # predict noise for upsampling (conditioned on low-res image)
                predicted_noise = self.low_res_upsampler(current_256_images, timesteps, generated_64x64)
                if t == 10:
                    print("predicted noise: ", predicted_noise.size())

                # update using reverse diffusion
                current_256_images, _ = self.low_res_upsampler.reverse_diffusion(
                    current_256_images, predicted_noise, timesteps, prev_timesteps
                )
                if t == 10:
                    print("current 256 images: ", current_256_images.size())
                t_counter += 1
            print("number of iters in upsampler one:", t_counter)

            self.images_256 = current_256_images

            # ====== SECOND UPSAMPLER: 256x256 -> 1024x1024 (if enabled) ======
            print("############################################################")
            print(" second upsampler ")
            print("############################################################")
            if self.use_high_res_upsampler and self.high_res_upsampler:
                upsampled_1024_noise = torch.randn((self.batch_size, self.initial_image_size[0], 1024, 1024), device=self.device)
                current_1024_images = upsampled_1024_noise

                t_counter = 0
                for t in reversed(range(self.high_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
                    timesteps = torch.full((self.batch_size,), t, device=self.device)
                    prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)

                    # predict noise for upsampling (conditioned on 256x256 image)
                    predicted_noise = self.high_res_upsampler(current_1024_images, timesteps, self.images_256)
                    if t == 10:
                        print("predicted noise: ", predicted_noise.size())

                    # update using reverse diffusion
                    current_1024_images, _ = self.high_res_upsampler.reverse_diffusion(
                        current_1024_images, predicted_noise, timesteps, prev_timesteps
                    )
                    if t == 10:
                        print("current 1024 images: ", current_1024_images.size())
                    t_counter += 1
                print("number of iters in upsampler two:", t_counter)

                self.images_1024 = current_1024_images

            # ====== POST-PROCESSING ======
            # normalize output to [0, 1] range if requested
            if normalize_output:
                final_256 = (self.images_256 - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
                final_1024 = None
                if self.images_1024 is not None:
                    final_1024 = (self.images_1024 - self.image_output_range[0]) / (
                        self.image_output_range[1] - self.image_output_range[0])
            else:
                final_256 = self.images_256
                final_1024 = self.images_1024

            # save images if requested
            if save_images:
                os.makedirs(save_path, exist_ok=True)
                os.makedirs(os.path.join(save_path, "images_256"), exist_ok=True)
                if final_1024 is not None:
                    os.makedirs(os.path.join(save_path, "images_1024"), exist_ok=True)

                for i in range(self.batch_size):
                    img_path_256 = os.path.join(save_path, "images_256", f"image_{i}.png")
                    torchvision.utils.save_image(final_256[i], img_path_256)

                    if final_1024 is not None:
                        img_path_1024 = os.path.join(save_path, "images_1024", f"image_{i}.png")
                        torchvision.utils.save_image(final_1024[i], img_path_1024)

            # return final images
            if final_1024 is not None:
                return final_1024
            else:
                return final_256

    def compute_prior_guided_prediction(
        self,
        predicted_embeddings: torch.Tensor,
        text_embeddings: torch.Tensor,
        current_embeddings: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """Computes classifier-free guidance for the prior model.

        Combines conditioned and unconditioned predictions using the classifier-free guidance
        formula to enhance the quality of generated image embeddings.

        Parameters
        ----------
        `predicted_embeddings` : torch.Tensor
            Conditioned predicted embeddings, shape (batch_size, embedding_dim).
        `text_embeddings` : torch.Tensor
            Text embeddings from CLIP, shape (batch_size, embedding_dim).
        `current_embeddings` : torch.Tensor
            Current noisy embeddings, shape (batch_size, embedding_dim).
        `timesteps` : torch.Tensor
            Timestep indices, shape (batch_size,).

        Returns
        -------
        guided_embeddings : torch.Tensor
            Guided embeddings, shape (batch_size, embedding_dim).
        """
        # use zero embeddings for unconditional generation
        zero_text_embeddings = torch.zeros_like(text_embeddings)
        unconditioned_pred = self.prior_model(zero_text_embeddings, current_embeddings, timesteps)

        # CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
        return (1.0 + self.prior_guidance_scale) * predicted_embeddings - self.prior_guidance_scale * unconditioned_pred

    def compute_decoder_guided_prediction(
        self,
        predicted_noise: torch.Tensor,
        current_images: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor
    ) -> torch.Tensor:
        """Computes classifier-free guidance for the decoder model.

        Combines conditioned and unconditioned noise predictions using the classifier-free
        guidance formula to enhance the quality of generated images.

        Parameters
        ----------
        `predicted_noise` : torch.Tensor
            Conditioned predicted noise, shape (batch_size, channels, height, width).
        `current_images` : torch.Tensor
            Current noisy images, shape (batch_size, channels, height, width).
        `timesteps` : torch.Tensor
            Timestep indices, shape (batch_size,).
        `context` : torch.Tensor
            Context embeddings (concatenated GLIDE text and projected image embeddings),
            shape (batch_size, seq_len, embedding_dim).

        Returns
        -------
        guided_noise : torch.Tensor
            Guided noise prediction, shape (batch_size, channels, height, width).
        """
        zero_context = torch.zeros_like(context)
        unconditioned_noise = self.decoder_model.noise_predictor(current_images, timesteps, zero_context, None)

        # CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
        return (1.0 + self.decoder_guidance_scale) * predicted_noise - self.decoder_guidance_scale * unconditioned_noise

    def to(self, device: Union[torch.device, str]) -> Self:
        """Moves the module and all its components to the specified device.

        Updates the device attribute and moves all sub-models (prior, decoder, CLIP,
        and upsamplers) to the specified device.

        Parameters
        ----------
        device : Union[torch.device, str]
            Target device for the module and its components.

        Returns
        -------
        SampleUnCLIP
            The module moved to the specified device.
        """
        if isinstance(device, str):
            device = torch.device(device)

        self.device = device

        # move all sub-models to the specified device
        self.prior_model.to(device)
        self.decoder_model.to(device)
        self.clip_model.to(device)
        self.low_res_upsampler.to(device)

        if self.high_res_upsampler is not None:
            self.high_res_upsampler.to(device)

        return super().to(device)


"""
from prior_model import UnCLIPTransformerPrior
from utils import NoisePredictor, TextEncoder
from clip_model import CLIPEncoder
from project_prior import Projection
import torch
from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP
from decoder_model import UnClipDecoder
from upsampler import UpsamplerUnCLIP

device = torch.device("cuda")


h_model = VarianceSchedulerUnCLIP(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    trainable_beta=True,
    beta_method="cosine"
).to(device)

c_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32").to(device)
tp = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=480,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
).to(device)
ip = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=480,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
).to(device)

d_model = ForwardUnCLIP(h_model).to(device)
r_model = ReverseUnCLIP(h_model).to(device)

prior_model = UnCLIPTransformerPrior(
    forward_diffusion=d_model,
    reverse_diffusion=r_model,
    text_projection=tp,
    image_projection=ip,
    embedding_dim=320,
    num_layers=12,
    num_attention_heads=8,
    feedforward_dim=512,
    max_sequence_length=2,
    dropout_rate=0.3
).to(device)


dn_model = NoisePredictor(
    in_channels=3,
    down_channels=[16, 32],
    mid_channels=[32, 32],
    up_channels=[32, 16],
    down_sampling=[True, True],
    time_embed_dim=512,
    y_embed_dim=512,
    num_down_blocks=2,
    num_mid_blocks=2,
    num_up_blocks=2,
    down_sampling_factor=2
).to(device)

dt_proj = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=468,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
).to(device)
di_proj = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=468,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
).to(device)

dh_model = VarianceSchedulerUnCLIP(
    num_steps=500,
    beta_start=1e-4,
    beta_end=0.02,
    trainable_beta=False,
    beta_method="linear"
).to(device)
dfor_ = ForwardUnCLIP(h_model).to(device)
drev_ = ReverseUnCLIP(h_model).to(device)

dcond = TextEncoder(
    use_pretrained_model=True,
    model_name="bert-base-uncased",
    vocabulary_size=30522,
    num_layers=2,
    input_dimension=512,
    output_dimension=512,
    num_heads=2,
    context_length=77
).to(device)

decoder_model = UnClipDecoder(
    embedding_dim=512,
    noise_predictor=dn_model,
    forward_diffusion=dfor_,
    reverse_diffusion=drev_,
    conditional_model=dcond,
    tokenizer=None,
    device="cuda",
    output_range=(-1.0, 1.0),
    normalize=True,
    classifier_free=0.1,
    drop_caption=0.5,
    max_length=77
).to(device)


hyp = VarianceSchedulerUnCLIP(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    trainable_beta=False,
    beta_method="cosine"
).to(device)


up_for = ForwardUnCLIP(hyp).to(device)
up_rev = ReverseUnCLIP(hyp).to(device)

upsampler_model_first = UpsamplerUnCLIP(
    forward_diffusion=up_for,
    reverse_diffusion=up_rev,
    in_channels=3,
    out_channels=3,
    model_channels=32,
    num_res_blocks=2,
    channel_mult=(1, 2, 4, 8),
    dropout=0.1,
    time_embed_dim=756,
    low_res_size=64,
    high_res_size=256
).to(device)

upsampler_model_second = UpsamplerUnCLIP(
    forward_diffusion=up_for,
    reverse_diffusion=up_rev,
    in_channels=3,
    out_channels=3,
    model_channels=32,
    num_res_blocks=2,
    channel_mult=(1, 2, 4, 8),
    dropout=0.1,
    time_embed_dim=756,
    low_res_size=256,
    high_res_size=1024
).to(device)


sampler = SampleUnCLIP(
    prior_model=prior_model,
    decoder_model=decoder_model,
    clip_model=c_model,
    low_res_upsampler=upsampler_model_first,
    high_res_upsampler=upsampler_model_second,
    device=None,
    prior_guidance_scale=4.0,
    decoder_guidance_scale=8.0,
    batch_size=1,
    normalize_clip_embeddings=True,
    prior_dim_reduction=True,
    clip_embedding_dim=512,
    initial_image_size=(3, 64, 64),
    use_high_res_upsampler=True,
    image_output_range=(-1.0, 1.0)
).to(device)

f = sampler(
    prompts=["this is a test prompt"],
    normalize_output=True,
    save_images=True,
    save_path="unclip_generated"
)
"""
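The two guided-prediction helpers in this file both reduce to the linear combination named in their comments, (1 + s) * conditioned - s * unconditioned. A minimal, self-contained sketch of that arithmetic on dummy tensors (illustrative only, not part of the packaged module; the tensor shapes and scale below are made up):

import torch

# stand-in tensors for a conditioned and an unconditioned prediction (batch of 1, 512-dim)
conditioned = torch.randn(1, 512)
unconditioned = torch.randn(1, 512)
guidance_scale = 4.0  # matches the default prior_guidance_scale in SampleUnCLIP

# same combination used by compute_prior_guided_prediction / compute_decoder_guided_prediction
guided = (1.0 + guidance_scale) * conditioned - guidance_scale * unconditioned

# with a scale of 0 the guided prediction collapses to the conditioned prediction
assert torch.allclose((1.0 + 0.0) * conditioned - 0.0 * unconditioned, conditioned)
print(guided.shape)  # torch.Size([1, 512])

Larger scales push the result further away from the unconditional branch, which is consistent with the decoder stage defaulting to a higher scale (8.0) than the prior stage (4.0).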