TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/decoder_model.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Optional, List, Tuple, Union
|
|
4
|
+
from project_decoder import CLIPContextProjection
|
|
5
|
+
from transformers import BertTokenizer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class UnClipDecoder(nn.Module):
    """Decoder for UnCLIP diffusion models.

    Combines CLIP image embeddings and text embeddings to guide the denoising process,
    using a noise predictor and diffusion processes. Incorporates classifier-free guidance,
    text caption dropout, and projection of CLIP embeddings into context tokens.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input CLIP embeddings.
    `noise_predictor` : nn.Module
        Model to predict noise during the denoising process. Called as
        ``noise_predictor(noisy_images, t, context, clip_image_embedding)``.
    `forward_diffusion` : nn.Module
        Forward diffusion module (e.g., ForwardUnCLIP) for adding noise. Must expose
        ``variance_scheduler.num_steps`` and is called as ``forward_diffusion(images, noise, t)``.
    `reverse_diffusion` : nn.Module
        Reverse diffusion module (e.g., ReverseUnCLIP) for denoising. Stored and moved
        to the device; not called by any method of this class.
    `glide_text_encoder` : nn.Module, optional
        GLIDE text encoder for processing text prompts, default None. When None, text
        prompts are ignored and only the projected CLIP tokens form the context.
    `bert_tokenizer` : BertTokenizer, optional
        Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `image_output_range` : Tuple[float, float], optional
        Intended range for output images (default: (-1.0, 1.0)). Stored only; not read
        by any method of this class.
    `normalize_clip_embeddings` : bool, optional
        Whether CLIP embeddings should be normalized (default: True). Stored only; not
        read by any method of this class.
    `classifier_free_prop` : float, optional
        Threshold below which the CLIP image embedding is zeroed for classifier-free
        guidance (default: 0.1, per paper).
    `drop_caption` : float, optional
        Threshold below which the text caption is dropped (default: 0.5, per paper).
    `max_token_length` : int, optional
        Maximum length for tokenized prompts (default: 77).
    """

    def __init__(
        self,
        clip_embedding_dim: int,
        noise_predictor: nn.Module,
        forward_diffusion: nn.Module,
        reverse_diffusion: nn.Module,
        glide_text_encoder: Optional[nn.Module] = None,  # GLIDE text encoder
        bert_tokenizer: Optional["BertTokenizer"] = None,
        device: Optional[Union[str, torch.device]] = None,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_clip_embeddings: bool = True,
        classifier_free_prop: float = 0.1,  # paper specifies 10%
        drop_caption: float = 0.5,  # paper specifies 50%
        max_token_length: int = 77,  # max_token_length for tokenization
    ) -> None:
        super().__init__()

        self.device = self._resolve_device(device)
        self.clip_embedding_dim = clip_embedding_dim

        # core models, all moved onto the resolved device
        self.noise_predictor = noise_predictor.to(self.device)
        self.forward_diffusion = forward_diffusion.to(self.device)
        self.reverse_diffusion = reverse_diffusion.to(self.device)
        # `is not None` rather than truthiness: a module defining __len__
        # (e.g. an empty nn.Sequential) would otherwise be silently discarded.
        self.glide_text_encoder = (
            glide_text_encoder.to(self.device) if glide_text_encoder is not None else None
        )

        # paper: "projecting CLIP embeddings into four extra tokens of context"
        self.clip_decoder_projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            num_tokens=4,
        ).to(self.device)
        # square linear map of the CLIP image embedding; its output is passed to the
        # noise predictor as the last argument in `forward`
        self.clip_time_projection = nn.Linear(
            self.clip_embedding_dim, self.clip_embedding_dim
        ).to(self.device)

        # training parameters
        self.image_output_range = image_output_range
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.classifier_free_prop = classifier_free_prop
        self.drop_caption = drop_caption
        self.max_token_length = max_token_length

        # use the caller's tokenizer if given (previously it was silently dropped),
        # otherwise load the default one
        self.bert_tokenizer = self._resolve_tokenizer(bert_tokenizer)

    @staticmethod
    def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
        """Resolve the ``device`` argument to a device object.

        ``None`` becomes CUDA when available (else CPU), a string is wrapped in
        ``torch.device``, and anything else is returned unchanged.
        """
        if device is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(device, str):
            return torch.device(device)
        return device

    @staticmethod
    def _resolve_tokenizer(bert_tokenizer: Optional["BertTokenizer"]) -> "BertTokenizer":
        """Return the caller-supplied tokenizer, or load "bert-base-uncased" when none is given.

        Raises
        ------
        ValueError
            If the default tokenizer cannot be loaded (the original error is chained).
        """
        if bert_tokenizer is not None:
            return bert_tokenizer
        try:
            return BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") from e

    def forward(
        self,
        image_embeddings: torch.Tensor,
        text_embeddings: torch.Tensor,
        images: torch.Tensor,
        texts: Union[List[str], torch.Tensor],
        p_classifier_free: float,
        p_text_drop: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Processes embeddings and images to predict noise for training.

        Applies classifier-free guidance and text dropout, projects CLIP image embeddings
        into context tokens, encodes text with GLIDE, and predicts noise for the diffusion process.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim). Only used to decide
            whether the caption was dropped; its values are not read.
        `images` : torch.Tensor
            Input images, shape (batch_size, channels, height, width).
        `texts` : list of str or torch.Tensor
            Text prompts for conditional generation.
        `p_classifier_free` : float
            Value compared against ``classifier_free_prop``; when smaller, the image
            embeddings of the whole batch are zeroed. Presumably a uniform draw in
            [0, 1) supplied by the trainer.
        `p_text_drop` : float
            Value compared against ``drop_caption``; when smaller, the caption is dropped
            for the whole batch.

        Returns
        -------
        predicted_noise : torch.Tensor
            Noise predicted by ``noise_predictor``; expected to match ``noise`` in shape.
        noise : torch.Tensor
            Ground truth Gaussian noise, same shape as ``images``.
        """
        image_embeddings = self._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
        text_embeddings = self._apply_text_dropout(text_embeddings, p_text_drop)

        # project z_i into 4 context tokens
        c = self.clip_decoder_projection(image_embeddings)

        # encode the caption with GLIDE unless text dropout removed it
        y_encoded = self._encode_text_with_glide(texts if text_embeddings is not None else None)

        # context = [text tokens (if any); CLIP tokens] along the sequence dimension
        context = self._concatenate_embeddings(y_encoded, c)

        # sample timestep and noise, then produce the noisy images
        t, noise = self._sample_timestep_and_noise(images.shape[0], images.shape)
        noisy_images = self.forward_diffusion(images, noise, t)

        clip_image_embedding = self.clip_time_projection(image_embeddings)
        predicted_noise = self.noise_predictor(noisy_images, t, context, clip_image_embedding)

        return predicted_noise, noise

    def inference_forward(self, image_embeddings, prompt_embeddings):
        """Placeholder for an inference-time forward pass.

        Raises
        ------
        NotImplementedError
            Always. The method has no implementation; raising makes accidental use
            visible instead of silently returning None.
        """
        raise NotImplementedError("UnClipDecoder.inference_forward is not implemented")

    def _apply_classifier_free_guidance(self, image_embeddings: torch.Tensor, p_value: float) -> torch.Tensor:
        """Applies classifier-free guidance to image embeddings.

        When ``p_value < self.classifier_free_prop``, the embeddings of the whole batch
        are replaced by zeros, as described in the UnCLIP paper.

        Parameters
        ----------
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against ``classifier_free_prop``.

        Returns
        -------
        image_embeddings : torch.Tensor
            Zeros of the same shape, or the input tensor unchanged.
        """
        if p_value < self.classifier_free_prop:
            # set z_i <- 0 {classifier-free guidance}
            return torch.zeros_like(image_embeddings)
        return image_embeddings

    def _apply_text_dropout(self, text_embeddings: torch.Tensor, p_value: float) -> Optional[torch.Tensor]:
        """Applies text caption dropout to text embeddings.

        When ``p_value < self.drop_caption``, the caption is dropped (``None`` is
        returned), as described in the UnCLIP paper.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        `p_value` : float
            Value compared against ``drop_caption``.

        Returns
        -------
        text_embeddings : torch.Tensor or None
            The input tensor unchanged, or None if dropped.
        """
        if p_value < self.drop_caption:
            # set y <- empty {drop text caption}
            return None
        return text_embeddings

    def _encode_text_with_glide(self, texts: Union[List, torch.Tensor, None]) -> Optional[torch.Tensor]:
        """Encodes text prompts using the GLIDE text encoder.

        Tokenizes with ``self.bert_tokenizer`` (padded/truncated to ``max_token_length``),
        moves the tokens to ``self.device``, and calls
        ``glide_text_encoder(input_ids, attention_mask)``.

        Parameters
        ----------
        `texts` : list, torch.Tensor or None
            Text prompts. A tensor is converted element-wise with ``str()`` after
            moving to CPU.

        Returns
        -------
        y_encoded : torch.Tensor or None
            The encoder output, or None when ``texts`` is None or no GLIDE encoder is
            configured. Its shape is whatever the encoder returns; the concatenation
            step accepts (batch_size, embed_dim) or (batch_size, seq_len, embed_dim).
        """
        if texts is None or self.glide_text_encoder is None:
            return None

        # the tokenizer needs a list of strings
        if isinstance(texts, torch.Tensor):
            texts = [str(item) for item in texts.cpu().numpy().tolist()]

        tokenized = self.bert_tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt",
        ).to(self.device)

        return self.glide_text_encoder(tokenized["input_ids"], tokenized["attention_mask"])

    def _concatenate_embeddings(self, y_encoded: Optional[torch.Tensor], c: torch.Tensor) -> torch.Tensor:
        """Concatenates GLIDE text embeddings and context tokens.

        Combines encoded text embeddings (if available) with projected context tokens
        along the sequence dimension (dim 1), as specified in the UnCLIP paper.

        Parameters
        ----------
        `y_encoded` : torch.Tensor or None
            Encoded text, shape (batch_size, seq_len, embed_dim) or (batch_size, embed_dim).
            A 2-D tensor is treated as a single token.
        `c` : torch.Tensor
            Projected context tokens, shape (batch_size, num_tokens, embed_dim).

        Returns
        -------
        s : torch.Tensor
            ``c`` itself when ``y_encoded`` is None, else the concatenation with shape
            (batch_size, seq_len + num_tokens, embed_dim).
        """
        if y_encoded is None:
            return c  # [batch_size, num_tokens, embed_dim]

        if y_encoded.dim() == 2:  # [batch_size, embed_dim] -> [batch_size, 1, embed_dim]
            y_encoded = y_encoded.unsqueeze(1)

        return torch.cat([y_encoded, c], dim=1)  # [batch_size, seq_len + num_tokens, embed_dim]

    def _sample_timestep_and_noise(self, batch_size: int, image_shape: torch.Size) -> Tuple[torch.Tensor, torch.Tensor]:
        """Samples timesteps and noise for the diffusion process.

        Parameters
        ----------
        `batch_size` : int
            Number of timesteps to sample (one per sample).
        `image_shape` : torch.Size
            Shape of the noise tensor, typically (batch_size, channels, height, width).

        Returns
        -------
        t : torch.Tensor
            Integer timestep indices drawn uniformly from {0, ..., num_steps - 1},
            shape (batch_size,).
        noise : torch.Tensor
            Standard Gaussian noise with shape ``image_shape``.
        """
        # t ~ Uniform{0, ..., T-1} (torch.randint's upper bound is exclusive)
        num_steps = self.forward_diffusion.variance_scheduler.num_steps
        t = torch.randint(0, num_steps, (batch_size,), device=self.device)
        # noise epsilon ~ N(0, I)
        noise = torch.randn(image_shape, device=self.device)
        return t, noise
|