TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,366 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import unittest
4
+ from PIL import Image
5
+ import os
6
+ import numpy as np
7
+ from typing import Optional
8
+ from torchdiff.unclip import(
9
+ VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP,
10
+ CLIPEncoder, CLIPContextProjection, CLIPEmbeddingProjection,
11
+ UnCLIPTransformerPrior, TrainUnCLIPPrior,
12
+ UnClipDecoder, TrainUnClipDecoder,
13
+ UpsamplerUnCLIP, TrainUpsamplerUnCLIP,
14
+ SampleUnCLIP
15
+ )
16
+
17
+
18
+ # Mock Noise Predictor for UnClipDecoder and UpsamplerUnCLIP
19
class MockNoisePredictor(nn.Module):
    """Minimal noise-predictor stand-in built from a single convolution.

    The timestep and both conditioning inputs are accepted so the call
    signature matches what the tests pass in, but only ``x`` is used.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        """Create the single convolution layer.

        Args:
            in_channels: Number of channels expected in the input tensor.
            out_channels: Number of channels produced by the layer.
        """
        super().__init__()
        # A 3x3 kernel with padding 1 (default stride 1) keeps height and
        # width unchanged, so the output matches the input spatially.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None,
                clip_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``self.conv(x)``; ``t``, ``context`` and ``clip_embedding`` are ignored."""
        prediction = self.conv(x)
        return prediction
27
+
28
+
29
+ # Mock GLIDE Text Encoder
30
class MockGLIDETextEncoder(nn.Module):
    """Text-encoder stand-in that emits random token embeddings.

    The token ids are only inspected for their batch size and device;
    the attention mask is accepted and ignored.
    """

    def __init__(self, embedding_dim: int = 512, max_length: int = 77):
        """Store the output dimensions.

        Args:
            embedding_dim: Size of the last dimension of the output.
            max_length: Size of the sequence dimension of the output.
        """
        super().__init__()
        self.max_length = max_length
        self.embedding_dim = embedding_dim

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a random ``(batch, max_length, embedding_dim)`` tensor on the device of ``input_ids``."""
        output_shape = (input_ids.shape[0], self.max_length, self.embedding_dim)
        return torch.randn(*output_shape, device=input_ids.device)
39
+
40
+
41
+ # Mock Metric (Loss Function)
42
class MockMetric:
    """Callable loss used by the tests: mean squared error between two tensors."""

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean of the element-wise squared difference of ``pred`` and ``target``."""
        residual = pred - target
        return residual.pow(2).mean()
45
+
46
+
47
+ # Mock DataLoader
48
class MockDataLoader:
    """Finite, re-iterable stand-in for a data loader of random image pairs.

    Each batch is a ``(low_res_images, high_res_images)`` tuple of random
    tensors shaped ``(batch_size, 3, size, size)``.

    Iteration stops after ``num_batches`` batches. Without that limit a
    ``for batch in loader`` loop (e.g. one training epoch) would never
    terminate.
    """

    def __init__(self, batch_size: int, low_res_size: int = 64, high_res_size: int = 256,
                 num_batches: int = 1):
        """Configure the batch shapes and the number of batches per pass.

        Args:
            batch_size: Number of images in every generated batch.
            low_res_size: Height and width of the low-resolution images.
            high_res_size: Height and width of the high-resolution images.
            num_batches: Batches produced per pass over the loader.

        Raises:
            ValueError: If ``num_batches`` is negative.
        """
        if num_batches < 0:
            raise ValueError(f"num_batches must be non-negative, got {num_batches}")
        self.batch_size = batch_size
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size
        self.num_batches = num_batches
        # Batches handed out during the current pass.
        self._batches_served = 0

    def __len__(self) -> int:
        """Return the number of batches produced per pass."""
        return self.num_batches

    def __iter__(self) -> "MockDataLoader":
        """Start a new pass so the loader can be iterated once per epoch."""
        self._batches_served = 0
        return self

    def __next__(self):
        """Return the next random ``(low_res_images, high_res_images)`` batch.

        Raises:
            StopIteration: Once ``num_batches`` batches have been produced
                in the current pass.
        """
        if self._batches_served >= self.num_batches:
            raise StopIteration
        self._batches_served += 1
        low_res_images = torch.randn(self.batch_size, 3, self.low_res_size, self.low_res_size)
        high_res_images = torch.randn(self.batch_size, 3, self.high_res_size, self.high_res_size)
        return low_res_images, high_res_images
61
+
62
+
63
class TestUnCLIP(unittest.TestCase):
    """Shape and sanity tests for the UnCLIP components of ``torchdiff.unclip``.

    Everything runs on the CPU with small random tensors. The CLIP model and
    processor are replaced by class-level mocks so that every test which needs
    them can build the same mocked encoder.
    """

    class _MockCLIPModel(nn.Module):
        """CLIP model stand-in returning random feature vectors."""

        def __init__(self, embedding_dim: int):
            super().__init__()
            # Kept on the mock itself: inside these methods ``self`` is the
            # mock module, not the test case, so the test case's
            # ``clip_embedding_dim`` attribute is not reachable from here.
            self.embedding_dim = embedding_dim

        def get_image_features(self, pixel_values):
            """Return one random embedding per row of ``pixel_values``."""
            return torch.randn(pixel_values.shape[0], self.embedding_dim)

        def get_text_features(self, input_ids, attention_mask=None):
            """Return one random embedding per row of ``input_ids``."""
            return torch.randn(input_ids.shape[0], self.embedding_dim)

    class _MockCLIPProcessor:
        """CLIP processor stand-in producing random model inputs."""

        def __call__(self, images=None, text=None, return_tensors="pt", padding=True, truncation=True):
            """Return random pixel values for ``images`` or random token ids for ``text``.

            Raises:
                ValueError: If neither ``images`` nor ``text`` is given.
            """
            if images:
                return {"pixel_values": torch.randn(len(images), 3, 224, 224)}
            if text:
                return {"input_ids": torch.randint(0, 1000, (len(text), 77)),
                        "attention_mask": torch.ones(len(text), 77)}
            # Fail loudly instead of handing ``None`` to the encoder.
            raise ValueError("MockCLIPProcessor requires either images or text")

    def setUp(self):
        """Build the scheduler, diffusion processes and mocks shared by the tests."""
        self.device = torch.device("cpu")  # Use CPU for testing to avoid CUDA dependency
        self.batch_size = 2
        self.clip_embedding_dim = 512
        self.image_size = (3, 64, 64)
        self.high_res_size = 256
        self.tau_num_steps = 10
        self.num_steps = 50

        # Initialize variance scheduler
        self.variance_scheduler = VarianceSchedulerUnCLIP(
            num_steps=self.num_steps,
            tau_num_steps=self.tau_num_steps,
            beta_start=1e-4,
            beta_end=0.02,
            beta_method="linear"
        ).to(self.device)

        # Initialize forward and reverse diffusion
        self.forward_diffusion = ForwardUnCLIP(self.variance_scheduler).to(self.device)
        self.reverse_diffusion = ReverseUnCLIP(self.variance_scheduler, prediction_type="noise").to(self.device)

        # Initialize mock components
        self.noise_predictor = MockNoisePredictor(in_channels=3, out_channels=3).to(self.device)
        self.glide_text_encoder = MockGLIDETextEncoder(embedding_dim=self.clip_embedding_dim).to(self.device)
        self.metric = MockMetric()

    # ------------------------------------------------------------------
    # Construction helpers shared by several tests
    # ------------------------------------------------------------------

    def _make_clip_encoder(self):
        """Return a ``CLIPEncoder`` whose model and processor are replaced by mocks."""
        # NOTE(review): presumably ``CLIPEncoder`` tries to load weights for
        # ``model_name`` when constructed - confirm that "mock" is accepted
        # before the attributes below are swapped in.
        clip_encoder = CLIPEncoder(model_name="mock", device=self.device)
        clip_encoder.model = self._MockCLIPModel(self.clip_embedding_dim).to(self.device)
        clip_encoder.processor = self._MockCLIPProcessor()
        return clip_encoder

    def _make_decoder(self):
        """Return an ``UnClipDecoder`` wired to the mocked predictor and text encoder."""
        return UnClipDecoder(
            clip_embedding_dim=self.clip_embedding_dim,
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            glide_text_encoder=self.glide_text_encoder,
            device=self.device
        )

    def _make_upsampler(self, **extra_kwargs):
        """Return an ``UpsamplerUnCLIP`` on the test device; ``extra_kwargs`` are forwarded to it."""
        return UpsamplerUnCLIP(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            in_channels=3,
            out_channels=3,
            model_channels=64,
            **extra_kwargs
        ).to(self.device)

    def _make_prior(self, **extra_kwargs):
        """Return an ``UnCLIPTransformerPrior`` on the test device; ``extra_kwargs`` are forwarded to it."""
        return UnCLIPTransformerPrior(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            transformer_embedding_dim=self.clip_embedding_dim,
            **extra_kwargs
        ).to(self.device)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_variance_scheduler_unclip(self):
        """The scheduler stores its configuration and yields betas inside the configured range."""
        scheduler = self.variance_scheduler

        # Test initialization
        self.assertEqual(scheduler.num_steps, self.num_steps)
        self.assertEqual(scheduler.tau_num_steps, self.tau_num_steps)
        self.assertEqual(scheduler.betas.shape, (self.num_steps,))
        self.assertTrue(torch.all(scheduler.betas >= 1e-4))
        self.assertTrue(torch.all(scheduler.betas <= 0.02))

        # Test beta schedule computation
        betas = scheduler.compute_beta_schedule((1e-4, 0.02), self.num_steps, "linear")
        self.assertEqual(betas.shape, (self.num_steps,))
        self.assertTrue(torch.all(betas >= 1e-4))
        self.assertTrue(torch.all(betas <= 0.02))

        # Test tau schedule; unpacking five values also checks the tuple length.
        tau_betas, tau_alphas, tau_alpha_cumprod, _, _ = scheduler.get_tau_schedule()
        self.assertEqual(tau_betas.shape, (self.tau_num_steps,))
        self.assertEqual(tau_alphas.shape, (self.tau_num_steps,))
        self.assertEqual(tau_alpha_cumprod.shape, (self.tau_num_steps,))

    def test_forward_unclip(self):
        """Forward diffusion keeps the input shape and rejects out-of-range time steps."""
        # Test forward diffusion
        x0 = torch.randn(self.batch_size, *self.image_size).to(self.device)
        noise = torch.randn_like(x0)
        time_steps = torch.randint(0, self.num_steps, (self.batch_size,), device=self.device)
        xt = self.forward_diffusion(x0, noise, time_steps)
        self.assertEqual(xt.shape, x0.shape)
        self.assertTrue(torch.all(torch.isfinite(xt)))

        # Test 2D input (latent embeddings)
        x0_2d = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        noise_2d = torch.randn_like(x0_2d)
        xt_2d = self.forward_diffusion(x0_2d, noise_2d, time_steps)
        self.assertEqual(xt_2d.shape, x0_2d.shape)

        # Test invalid time_steps; only the call under test sits inside the
        # ``assertRaises`` block so an error while building the input is not
        # mistaken for the expected one.
        invalid_time_steps = torch.tensor([self.num_steps], device=self.device)
        with self.assertRaises(ValueError):
            self.forward_diffusion(x0, noise, invalid_time_steps)

    def test_reverse_unclip(self):
        """Reverse diffusion returns finite, shape-preserving tensors for both prediction types."""
        # Test reverse diffusion (noise prediction)
        xt = torch.randn(self.batch_size, *self.image_size).to(self.device)
        model_prediction = torch.randn_like(xt)
        time_steps = torch.randint(0, self.tau_num_steps, (self.batch_size,), device=self.device)
        # Previous step index, floored at zero.
        prev_time_steps = torch.clamp(time_steps - 1, min=0)
        xt_prev, x0 = self.reverse_diffusion(xt, model_prediction, time_steps, prev_time_steps)
        self.assertEqual(xt_prev.shape, xt.shape)
        self.assertEqual(x0.shape, xt.shape)
        self.assertTrue(torch.all(torch.isfinite(xt_prev)))
        self.assertTrue(torch.all(torch.isfinite(x0)))

        # Test x0 prediction
        self.reverse_diffusion.set_prediction_type("x0")
        xt_prev, x0 = self.reverse_diffusion(xt, model_prediction, time_steps, prev_time_steps)
        self.assertEqual(xt_prev.shape, xt.shape)
        self.assertEqual(x0.shape, xt.shape)

        # Test invalid time_steps
        invalid_time_steps = torch.tensor([self.tau_num_steps], device=self.device)
        with self.assertRaises(ValueError):
            self.reverse_diffusion(xt, model_prediction, invalid_time_steps, prev_time_steps)

    def test_clip_encoder(self):
        """The mocked CLIP encoder yields per-sample embeddings and a pairwise similarity matrix."""
        clip_encoder = self._make_clip_encoder()
        expected_shape = (self.batch_size, self.clip_embedding_dim)

        # Test image encoding
        images = [
            Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
            for _ in range(self.batch_size)
        ]
        image_embeddings = clip_encoder(images, data_type="img")
        self.assertEqual(image_embeddings.shape, expected_shape)

        # Test text encoding
        texts = ["A test prompt"] * self.batch_size
        text_embeddings = clip_encoder(texts, data_type="text")
        self.assertEqual(text_embeddings.shape, expected_shape)

        # Test similarity
        similarity = clip_encoder.compute_similarity(image_embeddings, text_embeddings)
        self.assertEqual(similarity.shape, (self.batch_size, self.batch_size))

    def test_unclip_decoder(self):
        """The decoder returns image-shaped noise tensors and exposes its guidance and dropout helpers."""
        decoder = self._make_decoder()

        # Test forward pass
        image_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        text_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        images = torch.randn(self.batch_size, *self.image_size).to(self.device)
        texts = ["test prompt"] * self.batch_size
        predicted_noise, noise = decoder(image_embeddings, text_embeddings, images, texts, p_classifier_free=0.0,
                                         p_text_drop=0.0)
        self.assertEqual(predicted_noise.shape, images.shape)
        self.assertEqual(noise.shape, images.shape)

        # Test classifier-free guidance
        modified_embeddings = decoder._apply_classifier_free_guidance(image_embeddings, p_value=0.05)
        self.assertEqual(modified_embeddings.shape, image_embeddings.shape)

        # Test text dropout
        dropped_embeddings = decoder._apply_text_dropout(text_embeddings, p_value=0.6)
        self.assertIsNone(dropped_embeddings)

    def test_unclip_transformer_prior(self):
        """The prior maps text and noisy image embeddings to image-embedding-shaped output."""
        prior = self._make_prior(clip_text_projection=None, clip_image_projection=None)

        # Test forward pass
        text_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        noisy_image_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        timesteps = torch.randint(0, self.num_steps, (self.batch_size,), device=self.device)
        predicted_embeddings = prior(text_embeddings, noisy_image_embeddings, timesteps)
        self.assertEqual(predicted_embeddings.shape, (self.batch_size, self.clip_embedding_dim))

    def test_clip_context_projection(self):
        """The context projection expands one embedding into ``num_tokens`` context tokens."""
        num_tokens = 4
        projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim, num_tokens=num_tokens
        ).to(self.device)

        # Test forward pass
        z_i = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        c = projection(z_i)
        self.assertEqual(c.shape, (self.batch_size, num_tokens, self.clip_embedding_dim))

    def test_clip_embedding_projection(self):
        """The embedding projection reduces, restores and scores embeddings with the expected shapes."""
        reduced_dim = 320
        projection = CLIPEmbeddingProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            transformer_embedding_dim=reduced_dim
        ).to(self.device)

        # Test forward and inverse transform
        x = torch.randn(self.batch_size, self.clip_embedding_dim).to(self.device)
        x_reduced = projection(x)
        self.assertEqual(x_reduced.shape, (self.batch_size, reduced_dim))
        x_reconstructed = projection.inverse_transform(x_reduced)
        self.assertEqual(x_reconstructed.shape, x.shape)

        # Test reconstruction loss
        loss = projection.reconstruction_loss(x)
        self.assertTrue(torch.isfinite(loss))

    def test_upsampler_unclip(self):
        """The upsampler predicts noise with the shape of the high-resolution input."""
        low_res = self.image_size[-1]
        high_res = self.high_res_size
        upsampler = self._make_upsampler(
            num_res_blocks=2,
            low_res_size=low_res,
            high_res_size=high_res
        )

        # Test forward pass
        x_high = torch.randn(self.batch_size, 3, high_res, high_res).to(self.device)
        t = torch.randint(0, self.tau_num_steps, (self.batch_size,), device=self.device)
        x_low = torch.randn(self.batch_size, 3, low_res, low_res).to(self.device)
        predicted_noise = upsampler(x_high, t, x_low)
        self.assertEqual(predicted_noise.shape, (self.batch_size, 3, high_res, high_res))

    def test_train_upsampler_unclip(self):
        """One training epoch yields one recorded loss and a float best validation loss."""
        upsampler = self._make_upsampler(num_res_blocks=2)
        train_loader = MockDataLoader(batch_size=self.batch_size)
        optimizer = torch.optim.Adam(upsampler.parameters(), lr=1e-3)
        trainer = TrainUpsamplerUnCLIP(
            upsampler_model=upsampler,
            train_loader=train_loader,
            optimizer=optimizer,
            objective=self.metric,
            max_epochs=1,
            device=self.device,
            use_ddp=False,
            use_autocast=False
        )

        # Test training
        train_losses, best_val_loss = trainer()
        self.assertEqual(len(train_losses), 1)
        self.assertIsInstance(best_val_loss, float)

    def test_sample_unclip(self):
        """The full sampling pipeline returns finite high-resolution images and can save them."""
        # Local import: ``tempfile`` is not imported at module level.
        import tempfile

        sample_unclip = SampleUnCLIP(
            prior_model=self._make_prior(),
            decoder_model=self._make_decoder(),
            clip_model=self._make_clip_encoder(),
            low_res_upsampler=self._make_upsampler(),
            second_upsampler_model=None,
            device=self.device,
            batch_size=self.batch_size
        )

        # Test full pipeline
        prompts = ["A test image"] * self.batch_size
        final_images = sample_unclip(prompts=prompts, save_images=False)
        self.assertEqual(final_images.shape, (self.batch_size, 3, self.high_res_size, self.high_res_size))
        self.assertTrue(torch.all(torch.isfinite(final_images)))

        # Verify saved images. A temporary directory keeps the run from
        # leaving artifacts in the working directory.
        with tempfile.TemporaryDirectory() as save_dir:
            sample_unclip(prompts=prompts, save_images=True, save_path=save_dir)
            self.assertTrue(os.path.exists(os.path.join(save_dir, "images_256", "image_1.png")))
363
+
364
+
365
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()