torchdiff-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_ldm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,626 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from typing import Optional, Union, List, Tuple, Self
+ import os
+
+
+ class SampleUnCLIP(nn.Module):
+     """Generates images using the UnCLIP model pipeline.
+
+     Combines a prior model, decoder model, CLIP model, and upsampler models to generate
+     images from text prompts or noise. Performs diffusion-based sampling with classifier-free
+     guidance in both prior and decoder stages, followed by upsampling to higher resolutions.
+
+     Parameters
+     ----------
+     `prior_model` : nn.Module
+         The UnCLIP prior model for generating image embeddings from text.
+     `decoder_model` : nn.Module
+         The UnCLIP decoder model for generating low-resolution images from embeddings.
+     `clip_model` : nn.Module
+         CLIP model for encoding text prompts into embeddings.
+     `low_res_upsampler` : nn.Module
+         First upsampler model for scaling images from 64x64 to 256x256.
+     `high_res_upsampler` : nn.Module, optional
+         Second upsampler model for scaling images from 256x256 to 1024x1024, default None.
+     `device` : Union[torch.device, str], optional
+         Device for computation (default: CUDA if available, else CPU).
+     `clip_embedding_dim` : int, optional
+         Dimensionality of CLIP embeddings (default: 512).
+     `prior_guidance_scale` : float, optional
+         Classifier-free guidance scale for the prior model (default: 4.0).
+     `decoder_guidance_scale` : float, optional
+         Classifier-free guidance scale for the decoder model (default: 8.0).
+     `batch_size` : int, optional
+         Number of images to generate per batch (default: 1).
+     `normalize_clip_embeddings` : bool, optional
+         Whether to normalize CLIP embeddings (default: True).
+     `prior_dim_reduction` : bool, optional
+         Whether to apply dimensionality reduction in the prior model (default: True).
+     `initial_image_size` : Tuple[int, int, int], optional
+         Size of the initial generated images (default: (3, 64, 64) for RGB 64x64).
+     `use_high_res_upsampler` : bool, optional
+         Whether to use the second upsampler for 1024x1024 output (default: True).
+     `image_output_range` : Tuple[float, float], optional
+         Range for clamping output images (default: (-1.0, 1.0)).
+     """
+     def __init__(
+         self,
+         prior_model: nn.Module,
+         decoder_model: nn.Module,
+         clip_model: nn.Module,
+         low_res_upsampler: nn.Module,
+         high_res_upsampler: Optional[nn.Module] = None,
+         device: Optional[Union[torch.device, str]] = None,
+         clip_embedding_dim: int = 512,  # CLIP embedding dimension
+         prior_guidance_scale: float = 4.0,
+         decoder_guidance_scale: float = 8.0,
+         batch_size: int = 1,
+         normalize_clip_embeddings: bool = True,
+         prior_dim_reduction: bool = True,
+         initial_image_size: Tuple[int, int, int] = (3, 64, 64),
+         use_high_res_upsampler: bool = True,
+         image_output_range: Tuple[float, float] = (-1.0, 1.0)
+     ) -> None:
+         super().__init__()
+
+         if device is None:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         elif isinstance(device, str):
+             self.device = torch.device(device)
+         else:
+             self.device = device
+
+         self.prior_model = prior_model.to(self.device)
+         self.decoder_model = decoder_model.to(self.device)
+         self.clip_model = clip_model.to(self.device)
+         self.low_res_upsampler = low_res_upsampler.to(self.device)
+         self.high_res_upsampler = high_res_upsampler.to(self.device) if high_res_upsampler else None
+
+         self.prior_guidance_scale = prior_guidance_scale
+         self.decoder_guidance_scale = decoder_guidance_scale
+         self.batch_size = batch_size
+         self.normalize_clip_embeddings = normalize_clip_embeddings
+         self.prior_dim_reduction = prior_dim_reduction
+         self.clip_embedding_dim = clip_embedding_dim
+         self.initial_image_size = initial_image_size
+         self.use_high_res_upsampler = use_high_res_upsampler
+         self.image_output_range = image_output_range
+         self.images_256 = None
+         self.images_1024 = None
+
+     def forward(
+         self,
+         prompts: Optional[Union[str, List]] = None,
+         normalize_output: bool = True,
+         save_images: bool = True,
+         save_path: str = "unclip_generated"
+     ):
+         """Generates images from text prompts or noise using the UnCLIP pipeline.
+
+         Executes the full UnCLIP generation process: the prior model generates image
+         embeddings, the decoder model generates 64x64 images, the first upsampler
+         scales to 256x256, and the optional second upsampler scales to 1024x1024.
+         Supports classifier-free guidance and saves generated images if requested.
+
+         Parameters
+         ----------
+         `prompts` : Union[str, List], optional
+             Text prompt(s) for conditional generation, default None (unconditional).
+         `normalize_output` : bool, optional
+             Whether to normalize output images to [0, 1] range (default: True).
+         `save_images` : bool, optional
+             Whether to save generated images to disk (default: True).
+         `save_path` : str, optional
+             Directory to save generated images (default: "unclip_generated").
+
+         Returns
+         -------
+         final_images : torch.Tensor
+             Generated images, shape (batch_size, channels, height, width), either 256x256
+             or 1024x1024 depending on use_high_res_upsampler.
+         """
+         # initialize noise for prior sampling (image embedding space)
+         embedding_noise = torch.randn((self.batch_size, self.clip_embedding_dim), device=self.device)
+         print("embedding noise: ", embedding_noise.size())
+
+         with torch.no_grad():
+             # ====== PRIOR STAGE: generate image embeddings from text ======
+             print("############################################################")
+             print("                        prior model                         ")
+             print("############################################################")
+             # encode text prompt using CLIP
+             text_embeddings = self.clip_model(data=prompts, data_type="text", normalize=self.normalize_clip_embeddings)
+             print("text embedding: ", text_embeddings.size())
+
+             current_embeddings = embedding_noise.clone()
+
+             # optionally reduce dimensionality for prior model
+             if self.prior_dim_reduction:
+                 text_embeddings_reduced = self.prior_model.text_projection(text_embeddings)
+                 current_embeddings_reduced = self.prior_model.image_projection(current_embeddings)
+             else:
+                 text_embeddings_reduced = text_embeddings
+                 current_embeddings_reduced = current_embeddings
+             print("text embedding reduced: ", text_embeddings_reduced.size())
+             print("current embedding reduced: ", current_embeddings_reduced.size())
+
+             # prior diffusion sampling loop
+             t_counter = 0
+             for t in reversed(range(self.prior_model.forward_diffusion.variance_scheduler.tau_num_steps)):
+                 timesteps = torch.full((self.batch_size,), t, device=self.device)
+                 prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
+
+                 # predict embeddings
+                 predicted_embeddings = self.prior_model(text_embeddings_reduced, current_embeddings_reduced, timesteps)
+                 if t == 10:
+                     print("predicted embeddings: ", predicted_embeddings.size())
+
+                 # apply guidance
+                 guided_embeddings = self.compute_prior_guided_prediction(
+                     predicted_embeddings, text_embeddings_reduced, current_embeddings_reduced, timesteps
+                 )
+                 if t == 10:
+                     print("guided embeddings: ", guided_embeddings.size())
+
+                 # update embeddings using reverse diffusion
+                 current_embeddings_reduced, _ = self.prior_model.reverse_diffusion(
+                     current_embeddings_reduced, guided_embeddings, timesteps, prev_timesteps
+                 )
+                 if t == 10:
+                     print("current embedding reduced: ", current_embeddings_reduced.size())
+                 t_counter += 1
+
+             # convert back to full embedding dimension if needed
+             if self.prior_dim_reduction:
+                 final_image_embeddings = self.prior_model.image_projection.inverse_transform(current_embeddings_reduced)
+             else:
+                 final_image_embeddings = current_embeddings_reduced
+             print("final image embeddings: ", final_image_embeddings.size())
+             print("number of iters in prior model: ", t_counter)
+
+             # ====== DECODER STAGE: generate 64x64 images from embeddings ======
+             print("############################################################")
+             print("                       decoder model                        ")
+             print("############################################################")
+
+             # initialize noise for decoder sampling
+             decoder_noise = torch.randn(
+                 (self.batch_size, self.initial_image_size[0], self.initial_image_size[1], self.initial_image_size[2]),
+                 device=self.device
+             )
+             print("decoder noise: ", decoder_noise.size())
+
+             # project image embeddings to 4 tokens
+             projected_embeddings = self.decoder_model.decoder_projection(final_image_embeddings)
+             print("projected embeddings: ", projected_embeddings.size())
+
+             # encode text with GLIDE/decoder's text encoder
+             glide_text_embeddings = self.decoder_model._encode_text_with_glide(prompts)
+             print("glide text embeddings: ", glide_text_embeddings.size())
+
+             # concatenate embeddings for context
+             context = self.decoder_model._concatenate_embeddings(glide_text_embeddings, projected_embeddings)
+             print("context: ", context.size())
+
+             current_images = decoder_noise
+             # decoder diffusion sampling loop
+             t_counter = 0
+             for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
+                 timesteps = torch.full((self.batch_size,), t, device=self.device)
+                 prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
+
+                 # predict noise
+                 predicted_noise = self.decoder_model.noise_predictor(current_images, timesteps, context, None)
+                 if t == 10:
+                     print("predicted noise: ", predicted_noise.size())
+
+                 # apply guidance
+                 guided_noise = self.compute_decoder_guided_prediction(
+                     predicted_noise, current_images, timesteps, context
+                 )
+                 if t == 10:
+                     print("guided noise: ", guided_noise.size())
+
+                 # update images using reverse diffusion
+                 current_images, _ = self.decoder_model.reverse_diffusion(
+                     current_images, guided_noise, timesteps, prev_timesteps
+                 )
+                 if t == 10:
+                     print("current image: ", current_images.size())
+                 t_counter += 1
+
+             generated_64x64 = current_images
+             print("number of iters of decoder model: ", t_counter)
+
+             # ====== FIRST UPSAMPLER: 64x64 -> 256x256 ======
+             print("############################################################")
+             print("                      first upsampler                       ")
+             print("############################################################")
+             upsampled_256_noise = torch.randn((self.batch_size, self.initial_image_size[0], 256, 256), device=self.device)
+             current_256_images = upsampled_256_noise
+             print("upsampled 256 noise: ", upsampled_256_noise.size())
+
+             t_counter = 0
+             for t in reversed(range(self.low_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
+                 timesteps = torch.full((self.batch_size,), t, device=self.device)
+                 prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
+
+                 # predict noise for upsampling (conditioned on low-res image)
+                 predicted_noise = self.low_res_upsampler(current_256_images, timesteps, generated_64x64)
+                 if t == 10:
+                     print("predicted noise: ", predicted_noise.size())
+
+                 # update using reverse diffusion
+                 current_256_images, _ = self.low_res_upsampler.reverse_diffusion(
+                     current_256_images, predicted_noise, timesteps, prev_timesteps
+                 )
+                 if t == 10:
+                     print("current 256 images: ", current_256_images.size())
+                 t_counter += 1
+             print("number of iters in upsampler one:", t_counter)
+
+             self.images_256 = current_256_images
+
+             # ====== SECOND UPSAMPLER: 256x256 -> 1024x1024 (if enabled) ======
+             if self.use_high_res_upsampler and self.high_res_upsampler:
+                 print("############################################################")
+                 print("                      second upsampler                      ")
+                 print("############################################################")
+                 upsampled_1024_noise = torch.randn((self.batch_size, self.initial_image_size[0], 1024, 1024), device=self.device)
+                 current_1024_images = upsampled_1024_noise
+
+                 t_counter = 0
+                 for t in reversed(range(self.high_res_upsampler.forward_diffusion.variance_scheduler.tau_num_steps)):
+                     timesteps = torch.full((self.batch_size,), t, device=self.device)
+                     prev_timesteps = torch.full((self.batch_size,), max(t - 1, 0), device=self.device)
+
+                     # predict noise for upsampling (conditioned on 256x256 image)
+                     predicted_noise = self.high_res_upsampler(current_1024_images, timesteps, self.images_256)
+                     if t == 10:
+                         print("predicted noise: ", predicted_noise.size())
+
+                     # update using reverse diffusion
+                     current_1024_images, _ = self.high_res_upsampler.reverse_diffusion(
+                         current_1024_images, predicted_noise, timesteps, prev_timesteps
+                     )
+                     if t == 10:
+                         print("current 1024 images: ", current_1024_images.size())
+                     t_counter += 1
+                 print("number of iters in upsampler two:", t_counter)
+
+                 self.images_1024 = current_1024_images
+
+             # ====== POST-PROCESSING ======
+             # normalize output to [0, 1] range if requested
+             if normalize_output:
+                 final_256 = (self.images_256 - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
+                 final_1024 = None
+                 if self.images_1024 is not None:
+                     final_1024 = (self.images_1024 - self.image_output_range[0]) / (
+                         self.image_output_range[1] - self.image_output_range[0])
+             else:
+                 final_256 = self.images_256
+                 final_1024 = self.images_1024
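+             # NOTE (editorial): with the default image_output_range of (-1.0, 1.0),
+             # the normalization above is (x + 1) / 2, mapping -1 -> 0.0, 0 -> 0.5,
+             # and 1 -> 1.0, i.e. the [0, 1] range that torchvision.utils.save_image
+             # expects below.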
+
+             # save images if requested
+             if save_images:
+                 os.makedirs(save_path, exist_ok=True)
+                 os.makedirs(os.path.join(save_path, "images_256"), exist_ok=True)
+                 if final_1024 is not None:
+                     os.makedirs(os.path.join(save_path, "images_1024"), exist_ok=True)
+
+                 for i in range(self.batch_size):
+                     img_path_256 = os.path.join(save_path, "images_256", f"image_{i}.png")
+                     torchvision.utils.save_image(final_256[i], img_path_256)
+
+                     if final_1024 is not None:
+                         img_path_1024 = os.path.join(save_path, "images_1024", f"image_{i}.png")
+                         torchvision.utils.save_image(final_1024[i], img_path_1024)
+
+             # return final images
+             if final_1024 is not None:
+                 return final_1024
+             else:
+                 return final_256
+
+     def compute_prior_guided_prediction(
+         self,
+         predicted_embeddings: torch.Tensor,
+         text_embeddings: torch.Tensor,
+         current_embeddings: torch.Tensor,
+         timesteps: torch.Tensor
+     ) -> torch.Tensor:
+         """Computes classifier-free guidance for the prior model.
+
+         Combines conditioned and unconditioned predictions using the classifier-free guidance
+         formula to enhance the quality of generated image embeddings.
+
+         Parameters
+         ----------
+         `predicted_embeddings` : torch.Tensor
+             Conditioned predicted embeddings, shape (batch_size, embedding_dim).
+         `text_embeddings` : torch.Tensor
+             Text embeddings from CLIP, shape (batch_size, embedding_dim).
+         `current_embeddings` : torch.Tensor
+             Current noisy embeddings, shape (batch_size, embedding_dim).
+         `timesteps` : torch.Tensor
+             Timestep indices, shape (batch_size,).
+
+         Returns
+         -------
+         guided_embeddings : torch.Tensor
+             Guided embeddings, shape (batch_size, embedding_dim).
+         """
+         # use zero embeddings for unconditional generation
+         zero_text_embeddings = torch.zeros_like(text_embeddings)
+         unconditioned_pred = self.prior_model(zero_text_embeddings, current_embeddings, timesteps)
+
+         # CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
+         return (1.0 + self.prior_guidance_scale) * predicted_embeddings - self.prior_guidance_scale * unconditioned_pred
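+         # NOTE (editorial): the formula above rearranges to
+         #     conditioned + guidance_scale * (conditioned - unconditioned),
+         # so guidance_scale = 0 recovers the plain conditional prediction, and
+         # larger scales push the result further along the text-conditioned
+         # direction. The decoder guidance below uses the same identity.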
+
+     def compute_decoder_guided_prediction(
+         self,
+         predicted_noise: torch.Tensor,
+         current_images: torch.Tensor,
+         timesteps: torch.Tensor,
+         context: torch.Tensor
+     ) -> torch.Tensor:
+         """Computes classifier-free guidance for the decoder model.
+
+         Combines conditioned and unconditioned noise predictions using the classifier-free
+         guidance formula to enhance the quality of generated images.
+
+         Parameters
+         ----------
+         `predicted_noise` : torch.Tensor
+             Conditioned predicted noise, shape (batch_size, channels, height, width).
+         `current_images` : torch.Tensor
+             Current noisy images, shape (batch_size, channels, height, width).
+         `timesteps` : torch.Tensor
+             Timestep indices, shape (batch_size,).
+         `context` : torch.Tensor
+             Context embeddings (concatenated GLIDE text and projected image embeddings),
+             shape (batch_size, seq_len, embedding_dim).
+
+         Returns
+         -------
+         guided_noise : torch.Tensor
+             Guided noise prediction, shape (batch_size, channels, height, width).
+         """
+         zero_context = torch.zeros_like(context)
+         unconditioned_noise = self.decoder_model.noise_predictor(current_images, timesteps, zero_context, None)
+
+         # CFG formula: (1 + guidance_scale) * conditioned - guidance_scale * unconditioned
+         return (1.0 + self.decoder_guidance_scale) * predicted_noise - self.decoder_guidance_scale * unconditioned_noise
+
+     def to(self, device: Union[torch.device, str]) -> Self:
+         """Moves the module and all its components to the specified device.
+
+         Updates the device attribute and moves all sub-models (prior, decoder, CLIP,
+         and upsamplers) to the specified device.
+
+         Parameters
+         ----------
+         device : Union[torch.device, str]
+             Target device for the module and its components.
+
+         Returns
+         -------
+         SampleUnCLIP
+             The module moved to the specified device.
+         """
+         if isinstance(device, str):
+             device = torch.device(device)
+
+         self.device = device
+
+         # move all sub-models to the specified device
+         self.prior_model.to(device)
+         self.decoder_model.to(device)
+         self.clip_model.to(device)
+         self.low_res_upsampler.to(device)
+
+         if self.high_res_upsampler is not None:
+             self.high_res_upsampler.to(device)
+
+         return super().to(device)
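+     # NOTE (editorial): the sub-models are registered child modules, so
+     # nn.Module.to() would already move them; the main effect of this override
+     # is keeping self.device in sync, since forward() allocates fresh noise
+     # tensors with device=self.device.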
+
+
+ """
+ from prior_model import UnCLIPTransformerPrior
+ from utils import NoisePredictor, TextEncoder
+ from clip_model import CLIPEncoder
+ from project_prior import Projection
+ import torch
+ from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP
+ from decoder_model import UnClipDecoder
+ from upsampler import UpsamplerUnCLIP
+
+ device = torch.device("cuda")
+
+
+ h_model = VarianceSchedulerUnCLIP(
+     num_steps=1000,
+     beta_start=1e-4,
+     beta_end=0.02,
+     trainable_beta=True,
+     beta_method="cosine"
+ ).to(device)
+
+ c_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32").to(device)
+ tp = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=480,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ ).to(device)
+ ip = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=480,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ ).to(device)
+
+ d_model = ForwardUnCLIP(h_model).to(device)
+ r_model = ReverseUnCLIP(h_model).to(device)
+
+ prior_model = UnCLIPTransformerPrior(
+     forward_diffusion=d_model,
+     reverse_diffusion=r_model,
+     text_projection=tp,
+     image_projection=ip,
+     embedding_dim=320,
+     num_layers=12,
+     num_attention_heads=8,
+     feedforward_dim=512,
+     max_sequence_length=2,
+     dropout_rate=0.3
+ ).to(device)
+
+
+ dn_model = NoisePredictor(
+     in_channels=3,
+     down_channels=[16, 32],
+     mid_channels=[32, 32],
+     up_channels=[32, 16],
+     down_sampling=[True, True],
+     time_embed_dim=512,
+     y_embed_dim=512,
+     num_down_blocks=2,
+     num_mid_blocks=2,
+     num_up_blocks=2,
+     down_sampling_factor=2
+ ).to(device)
+
+ dt_proj = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=468,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ ).to(device)
+ di_proj = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=468,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ ).to(device)
+
+ dh_model = VarianceSchedulerUnCLIP(
+     num_steps=500,
+     beta_start=1e-4,
+     beta_end=0.02,
+     trainable_beta=False,
+     beta_method="linear"
+ ).to(device)
+ dfor_ = ForwardUnCLIP(dh_model).to(device)
+ drev_ = ReverseUnCLIP(dh_model).to(device)
+
+ dcond = TextEncoder(
+     use_pretrained_model=True,
+     model_name="bert-base-uncased",
+     vocabulary_size=30522,
+     num_layers=2,
+     input_dimension=512,
+     output_dimension=512,
+     num_heads=2,
+     context_length=77
+ ).to(device)
+
+ decoder_model = UnClipDecoder(
+     embedding_dim=512,
+     noise_predictor=dn_model,
+     forward_diffusion=dfor_,
+     reverse_diffusion=drev_,
+     conditional_model=dcond,
+     tokenizer=None,
+     device="cuda",
+     output_range=(-1.0, 1.0),
+     normalize=True,
+     classifier_free=0.1,
+     drop_caption=0.5,
+     max_length=77
+ ).to(device)
+
+
+ hyp = VarianceSchedulerUnCLIP(
+     num_steps=1000,
+     beta_start=1e-4,
+     beta_end=0.02,
+     trainable_beta=False,
+     beta_method="cosine"
+ ).to(device)
+
+ up_for = ForwardUnCLIP(hyp).to(device)
+ up_rev = ReverseUnCLIP(hyp).to(device)
+
+ upsampler_model_first = UpsamplerUnCLIP(
+     forward_diffusion=up_for,
+     reverse_diffusion=up_rev,
+     in_channels=3,
+     out_channels=3,
+     model_channels=32,
+     num_res_blocks=2,
+     channel_mult=(1, 2, 4, 8),
+     dropout=0.1,
+     time_embed_dim=756,
+     low_res_size=64,
+     high_res_size=256
+ ).to(device)
+
+ upsampler_model_second = UpsamplerUnCLIP(
+     forward_diffusion=up_for,
+     reverse_diffusion=up_rev,
+     in_channels=3,
+     out_channels=3,
+     model_channels=32,
+     num_res_blocks=2,
+     channel_mult=(1, 2, 4, 8),
+     dropout=0.1,
+     time_embed_dim=756,
+     low_res_size=256,
+     high_res_size=1024
+ ).to(device)
+
+
+ sampler = SampleUnCLIP(
+     prior_model=prior_model,
+     decoder_model=decoder_model,
+     clip_model=c_model,
+     low_res_upsampler=upsampler_model_first,
+     high_res_upsampler=upsampler_model_second,
+     device=None,
+     clip_embedding_dim=512,
+     prior_guidance_scale=4.0,
+     decoder_guidance_scale=8.0,
+     batch_size=1,
+     normalize_clip_embeddings=True,
+     prior_dim_reduction=True,
+     initial_image_size=(3, 64, 64),
+     use_high_res_upsampler=True,
+     image_output_range=(-1.0, 1.0)
+ ).to(device)
+
+ f = sampler(
+     prompts=["this is a test prompt"],
+     normalize_output=True,
+     save_images=True,
+     save_path="unclip_generated"
+ )
+ """
+