TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,312 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List, Tuple, Union
4
+ from project_decoder import CLIPContextProjection
5
+ from transformers import BertTokenizer
6
+
7
+
8
class UnClipDecoder(nn.Module):
    """Decoder for UnCLIP diffusion models.

    Combines CLIP image embeddings and GLIDE-encoded text prompts to condition a
    noise predictor during diffusion training. Implements classifier-free
    guidance (zeroing the CLIP image embeddings), text caption dropout, and
    projection of the CLIP image embedding into four extra context tokens.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the CLIP embeddings; used as the input size of
        `clip_decoder_projection` and the input/output size of
        `clip_time_projection`.
    `noise_predictor` : nn.Module
        Model called as ``noise_predictor(noisy_images, t, context,
        clip_image_embedding)`` to predict the added noise.
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) called as
        ``forward_diffusion(images, noise, t)``; must expose
        ``variance_scheduler.num_steps``.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP); stored and moved to
        `device`, but not called by any method of this class.
    `glide_text_encoder` : nn.Module, optional
        Text encoder called as ``glide_text_encoder(input_ids, attention_mask)``.
        When None, no text context is produced (default: None).
    `bert_tokenizer` : BertTokenizer, optional
        Tokenizer for text prompts. When None, "bert-base-uncased" is loaded.
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `image_output_range` : Tuple[float, float], optional
        Output range for images (default: (-1.0, 1.0)). Stored only; no method
        of this class reads it.
    `normalize_clip_embeddings` : bool, optional
        Normalization flag (default: True). Stored only; no method of this
        class reads it.
    `classifier_free_prop` : float, optional
        Classifier-free guidance threshold: the CLIP image embeddings are zeroed
        when the `p_classifier_free` value passed to `forward` is below it
        (default: 0.1, the paper's 10%).
    `drop_caption` : float, optional
        Caption-dropout threshold: text conditioning is dropped when the
        `p_text_drop` value passed to `forward` is below it
        (default: 0.5, the paper's 50%).
    `max_token_length` : int, optional
        Maximum length for tokenized prompts (default: 77).

    Raises
    ------
    ValueError
        If no tokenizer is supplied and the default one cannot be loaded.
    """

    def __init__(
        self,
        clip_embedding_dim: int,
        noise_predictor: nn.Module,
        forward_diffusion: nn.Module,
        reverse_diffusion: nn.Module,
        glide_text_encoder: Optional[nn.Module] = None,  # GLIDE text encoder
        bert_tokenizer: Optional[BertTokenizer] = None,
        device: Optional[Union[str, torch.device]] = None,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_clip_embeddings: bool = True,
        classifier_free_prop: float = 0.1,  # paper specifies 10%
        drop_caption: float = 0.5,  # paper specifies 50%
        max_token_length: int = 77,  # max_token_length for tokenization
    ) -> None:
        super().__init__()

        self.device = self._resolve_device(device)
        self.clip_embedding_dim = clip_embedding_dim

        # core models
        self.noise_predictor = noise_predictor.to(self.device)
        self.forward_diffusion = forward_diffusion.to(self.device)
        self.reverse_diffusion = reverse_diffusion.to(self.device)
        # `is not None` rather than truthiness: container modules such as an
        # empty nn.Sequential define __len__ and would evaluate as False.
        self.glide_text_encoder = (
            glide_text_encoder.to(self.device) if glide_text_encoder is not None else None
        )

        # paper: "projecting CLIP embeddings into four extra tokens of context"
        self.clip_decoder_projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            num_tokens=4,
        ).to(self.device)
        self.clip_time_projection = nn.Linear(
            self.clip_embedding_dim, self.clip_embedding_dim
        ).to(self.device)

        # training parameters
        self.image_output_range = image_output_range
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.classifier_free_prop = classifier_free_prop
        self.drop_caption = drop_caption
        self.max_token_length = max_token_length

        # use the caller's tokenizer when given; otherwise load the default
        self.bert_tokenizer = self._resolve_tokenizer(bert_tokenizer)

    @staticmethod
    def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
        """Return `device` as a torch.device, defaulting to CUDA if available, else CPU."""
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(device, str):
            return torch.device(device)
        return device

    @staticmethod
    def _resolve_tokenizer(bert_tokenizer: Optional[BertTokenizer]) -> BertTokenizer:
        """Return the supplied tokenizer, or load "bert-base-uncased" when it is None.

        Raises
        ------
        ValueError
            If the default tokenizer cannot be loaded; the original error is
            chained as the cause.
        """
        if bert_tokenizer is not None:
            return bert_tokenizer
        try:
            return BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(
                f"Failed to load default tokenizer: {e}. Please provide a tokenizer."
            ) from e

    def forward(
        self,
        image_embeddings: torch.Tensor,
        text_embeddings: torch.Tensor,
        images: torch.Tensor,
        texts: Union[List[str], torch.Tensor],
        p_classifier_free: float,
        p_text_drop: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict noise for one training step.

        Applies classifier-free guidance and caption dropout, projects the CLIP
        image embeddings into context tokens, and encodes the prompts with the
        GLIDE text encoder. It then noises the images at random timesteps and
        runs the noise predictor. Both the guidance and the dropout decisions
        are made once for the whole batch.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, clip_embedding_dim).
        `text_embeddings` : torch.Tensor
            CLIP text embeddings. Only used as a caption-dropout signal: text
            context is produced only when this is not dropped. Its values are
            not read.
        `images` : torch.Tensor
            Input images; assumed (batch_size, channels, height, width).
        `texts` : Union[List[str], torch.Tensor]
            Text prompts; a tensor is converted element-wise to strings.
        `p_classifier_free` : float
            Value compared against `classifier_free_prop` (presumably a uniform
            random draw supplied by the caller).
        `p_text_drop` : float
            Value compared against `drop_caption` (presumably a uniform random
            draw supplied by the caller).

        Returns
        -------
        predicted_noise : torch.Tensor
            Output of `noise_predictor` for the noisy images.
        noise : torch.Tensor
            Gaussian noise that was added to `images`; same shape as `images`.
        """
        image_embeddings = self._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
        text_embeddings = self._apply_text_dropout(text_embeddings, p_text_drop)

        # project z_i to 4 context tokens
        c = self.clip_decoder_projection(image_embeddings)

        # encode text with GLIDE; a dropped caption yields no text context
        y_encoded = self._encode_text_with_glide(texts if text_embeddings is not None else None)

        # text tokens (if any) followed by the 4 CLIP context tokens
        context = self._concatenate_embeddings(y_encoded, c)

        # sample timesteps and noise, then produce the noisy images
        t, noise = self._sample_timestep_and_noise(images.shape[0], images.shape)
        noisy_images = self.forward_diffusion(images, noise, t)

        # the (possibly zeroed) CLIP embedding is also projected and passed to
        # the noise predictor as a separate conditioning input
        clip_image_embedding = self.clip_time_projection(image_embeddings)

        predicted_noise = self.noise_predictor(noisy_images, t, context, clip_image_embedding)
        return predicted_noise, noise

    def inference_forward(self, image_embeddings, prompt_embeddings):
        """Placeholder for an inference-time forward pass; not implemented.

        Raises
        ------
        NotImplementedError
            Always; this entry point has no implementation yet.
        """
        raise NotImplementedError("UnClipDecoder.inference_forward is not implemented")

    def _apply_classifier_free_guidance(self, image_embeddings: torch.Tensor, p_value: float) -> torch.Tensor:
        """Zero the CLIP image embeddings when `p_value` < `classifier_free_prop`.

        Implements classifier-free guidance as described in the UnCLIP paper.
        The decision applies to the whole batch.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against `classifier_free_prop`.

        Returns
        -------
        image_embeddings : torch.Tensor
            A zero tensor of the same shape, or the input unchanged.
        """
        if p_value < self.classifier_free_prop:
            # set z_i <- 0 {classifier-free guidance}
            image_embeddings = torch.zeros_like(image_embeddings)
        return image_embeddings

    def _apply_text_dropout(self, text_embeddings: torch.Tensor, p_value: float) -> Optional[torch.Tensor]:
        """Drop the text conditioning when `p_value` < `drop_caption`.

        Implements caption dropout as described in the UnCLIP paper.
        The decision applies to the whole batch.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against `drop_caption`.

        Returns
        -------
        text_embeddings : torch.Tensor or None
            None if dropped, otherwise the input unchanged.
        """
        if p_value < self.drop_caption:
            # set y <- {} {drop text caption}
            return None
        return text_embeddings

    def _encode_text_with_glide(self, texts: Optional[Union[List, torch.Tensor]]) -> Optional[torch.Tensor]:
        """Tokenize prompts and encode them with the GLIDE text encoder.

        Parameters
        ----------
        `texts` : Union[List, torch.Tensor] or None
            Text prompts. A tensor is converted element-wise to strings.

        Returns
        -------
        y_encoded : torch.Tensor or None
            Output of `glide_text_encoder(input_ids, attention_mask)`, or None
            when `texts` is None or no text encoder is configured.
        """
        if texts is None or self.glide_text_encoder is None:
            return None

        # the tokenizer expects strings; Tensor.tolist() also handles dtypes
        # numpy cannot represent (e.g. bfloat16)
        if isinstance(texts, torch.Tensor):
            texts = [str(item) for item in texts.tolist()]

        tokenized = self.bert_tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt",
        ).to(self.device)

        return self.glide_text_encoder(tokenized["input_ids"], tokenized["attention_mask"])

    def _concatenate_embeddings(self, y_encoded: Optional[torch.Tensor], c: torch.Tensor) -> torch.Tensor:
        """Concatenate GLIDE text embeddings and CLIP context tokens along dim 1.

        Parameters
        ----------
        `y_encoded` : torch.Tensor or None
            Encoded text, (batch_size, seq_len, embedding_dim) or
            (batch_size, embedding_dim); a 2-D input is treated as one token.
        `c` : torch.Tensor
            Projected context tokens, (batch_size, num_tokens, embedding_dim).

        Returns
        -------
        s : torch.Tensor
            (batch_size, seq_len + num_tokens, embedding_dim), or `c` itself
            when `y_encoded` is None.
        """
        if y_encoded is None:
            return c
        if y_encoded.dim() == 2:  # [batch_size, embed_dim] -> [batch_size, 1, embed_dim]
            y_encoded = y_encoded.unsqueeze(1)
        return torch.cat([y_encoded, c], dim=1)

    def _sample_timestep_and_noise(self, batch_size: int, image_shape: torch.Size) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample diffusion timesteps and Gaussian noise.

        Parameters
        ----------
        `batch_size` : int
            Number of timesteps to sample.
        `image_shape` : torch.Size
            Shape of the noise tensor to draw, typically the image batch shape.

        Returns
        -------
        t : torch.Tensor
            Integer timestep indices in [0, num_steps), shape (batch_size,).
        noise : torch.Tensor
            Standard Gaussian noise (default float dtype) of shape `image_shape`.
        """
        # t ~ Uniform{0, ..., num_steps - 1} (0-indexed timesteps)
        num_steps = self.forward_diffusion.variance_scheduler.num_steps
        t = torch.randint(0, num_steps, (batch_size,), device=self.device)
        # noise eps ~ N(0, I)
        noise = torch.randn(image_shape, device=self.device)
        return t, noise