TorchDiff 2.0.0 (torchdiff-2.0.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import unittest
|
|
4
|
+
from PIL import Image
|
|
5
|
+
import os
|
|
6
|
+
import numpy as np
|
|
7
|
+
from typing import Optional
|
|
8
|
+
from torchdiff.unclip import(
|
|
9
|
+
VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP,
|
|
10
|
+
CLIPEncoder, CLIPContextProjection, CLIPEmbeddingProjection,
|
|
11
|
+
UnCLIPTransformerPrior, TrainUnCLIPPrior,
|
|
12
|
+
UnClipDecoder, TrainUnClipDecoder,
|
|
13
|
+
UpsamplerUnCLIP, TrainUpsamplerUnCLIP,
|
|
14
|
+
SampleUnCLIP
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MockNoisePredictor(nn.Module):
    """Lightweight noise-predictor stub shared by the UnClipDecoder and UpsamplerUnCLIP tests.

    A single 3x3 convolution with ``padding=1`` maps ``in_channels`` to
    ``out_channels`` while keeping height and width unchanged. The time step and
    all conditioning inputs are accepted only for signature compatibility.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        clip_embedding: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return ``conv(x)``; ``t``, ``context`` and ``clip_embedding`` are ignored."""
        features = self.conv(x)
        return features
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MockGLIDETextEncoder(nn.Module):
    """Stand-in for the GLIDE text encoder that returns random token embeddings.

    Args:
        embedding_dim: Size of each token embedding.
        max_length: Number of token positions in every output sequence.
    """

    def __init__(self, embedding_dim: int = 512, max_length: int = 77):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample a ``(batch, max_length, embedding_dim)`` tensor on ``input_ids``' device.

        Only the leading (batch) dimension and the device of ``input_ids`` are
        used; token values and ``attention_mask`` are ignored.
        """
        output_shape = (input_ids.size(0), self.max_length, self.embedding_dim)
        return torch.randn(*output_shape, device=input_ids.device)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class MockMetric:
    """Mean-squared-error objective used as a stand-in training loss.

    Calling an instance returns the mean of the element-wise squared
    difference between ``pred`` and ``target`` (standard broadcasting applies).
    """

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        residual = pred - target
        return residual.pow(2).mean()
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class MockDataLoader:
    """Finite, re-iterable stand-in for a ``DataLoader`` of (low-res, high-res) pairs.

    Every batch is freshly sampled from a standard normal, so contents are
    random but shapes are fixed: ``(batch_size, 3, low_res_size, low_res_size)``
    and ``(batch_size, 3, high_res_size, high_res_size)``.

    Iterating the loader yields exactly ``num_batches`` batches and can be
    repeated (one pass per epoch), like ``torch.utils.data.DataLoader``.
    A bare ``next(loader)`` still returns a single fresh batch.

    Args:
        batch_size: Number of samples per batch.
        low_res_size: Side length of the low-resolution images.
        high_res_size: Side length of the high-resolution images.
        num_batches: Batches produced per pass; small by default so training
            smoke tests finish quickly.

    Raises:
        ValueError: If ``num_batches`` is negative.
    """

    def __init__(self, batch_size: int, low_res_size: int = 64, high_res_size: int = 256,
                 num_batches: int = 1):
        if num_batches < 0:
            raise ValueError(f"num_batches must be non-negative, got {num_batches}")
        self.batch_size = batch_size
        self.low_res_size = low_res_size
        self.high_res_size = high_res_size
        self.num_batches = num_batches

    def __len__(self) -> int:
        return self.num_batches

    def __iter__(self):
        # Hand out a fresh generator per pass so each epoch stops after
        # ``num_batches``; returning ``self`` with a never-exhausting
        # ``__next__`` would make every loop over the loader infinite.
        for _ in range(self.num_batches):
            yield self.__next__()

    def __next__(self):
        """Return one freshly sampled ``(low_res_images, high_res_images)`` batch."""
        low_res_images = torch.randn(self.batch_size, 3, self.low_res_size, self.low_res_size)
        high_res_images = torch.randn(self.batch_size, 3, self.high_res_size, self.high_res_size)
        return low_res_images, high_res_images
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class MockCLIPModel(nn.Module):
    """Stand-in for a CLIP model whose feature extractors return random vectors.

    Only the two methods these tests need are provided. The embedding width is
    stored on the mock itself so it does not depend on any test-case state.

    Args:
        embedding_dim: Width of the returned feature vectors.
    """

    def __init__(self, embedding_dim: int = 512):
        super().__init__()
        self.embedding_dim = embedding_dim

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return a random ``(batch, embedding_dim)`` tensor on ``pixel_values``' device."""
        return torch.randn(pixel_values.shape[0], self.embedding_dim, device=pixel_values.device)

    def get_text_features(self, input_ids: torch.Tensor,
                          attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return a random ``(batch, embedding_dim)`` tensor on ``input_ids``' device."""
        return torch.randn(input_ids.shape[0], self.embedding_dim, device=input_ids.device)


class MockCLIPProcessor:
    """Stand-in for a CLIP processor that emits random, correctly shaped model inputs.

    Raises:
        ValueError: When called with neither ``images`` nor ``text``.
    """

    def __call__(self, images=None, text=None, return_tensors="pt", padding=True, truncation=True):
        if images is not None:
            return {"pixel_values": torch.randn(len(images), 3, 224, 224)}
        if text is not None:
            return {"input_ids": torch.randint(0, 1000, (len(text), 77)),
                    "attention_mask": torch.ones(len(text), 77)}
        raise ValueError("MockCLIPProcessor requires either `images` or `text`.")


class TestUnCLIP(unittest.TestCase):
    """Smoke tests for the UnCLIP building blocks exported by ``torchdiff.unclip``.

    Everything runs on CPU with lightweight mock sub-models; the tests check
    output shapes, finiteness and input validation rather than sample quality.
    """

    def setUp(self):
        # Seed the global torch RNG so random mocks/noise (and any failure) are reproducible.
        torch.manual_seed(0)
        self.device = torch.device("cpu")  # Use CPU for testing to avoid CUDA dependency
        self.batch_size = 2
        self.clip_embedding_dim = 512
        self.image_size = (3, 64, 64)
        self.high_res_size = 256
        self.tau_num_steps = 10
        self.num_steps = 50

        self.variance_scheduler = VarianceSchedulerUnCLIP(
            num_steps=self.num_steps,
            tau_num_steps=self.tau_num_steps,
            beta_start=1e-4,
            beta_end=0.02,
            beta_method="linear",
        ).to(self.device)

        self.forward_diffusion = ForwardUnCLIP(self.variance_scheduler).to(self.device)
        self.reverse_diffusion = ReverseUnCLIP(self.variance_scheduler, prediction_type="noise").to(self.device)

        self.noise_predictor = MockNoisePredictor(in_channels=3, out_channels=3).to(self.device)
        self.glide_text_encoder = MockGLIDETextEncoder(embedding_dim=self.clip_embedding_dim).to(self.device)
        self.metric = MockMetric()

    # ------------------------------------------------------------------ helpers

    def _make_clip_encoder(self):
        """Build a ``CLIPEncoder`` whose model and processor are replaced by mocks."""
        # NOTE(review): CLIPEncoder is constructed with model_name="mock" before the
        # mocks are swapped in; confirm its constructor does not try to load that name.
        clip_encoder = CLIPEncoder(model_name="mock", device=self.device)
        clip_encoder.model = MockCLIPModel(embedding_dim=self.clip_embedding_dim).to(self.device)
        clip_encoder.processor = MockCLIPProcessor()
        return clip_encoder

    def _make_decoder(self):
        """Build an ``UnClipDecoder`` wired to the shared mock components."""
        return UnClipDecoder(
            clip_embedding_dim=self.clip_embedding_dim,
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            glide_text_encoder=self.glide_text_encoder,
            device=self.device,
        )

    # -------------------------------------------------------------------- tests

    def test_variance_scheduler_unclip(self):
        scheduler = self.variance_scheduler

        # Initialization
        self.assertEqual(scheduler.num_steps, self.num_steps)
        self.assertEqual(scheduler.tau_num_steps, self.tau_num_steps)
        self.assertEqual(scheduler.betas.shape, (self.num_steps,))
        self.assertTrue(torch.all(scheduler.betas >= 1e-4))
        self.assertTrue(torch.all(scheduler.betas <= 0.02))

        # Beta schedule computed directly from (beta_start, beta_end)
        betas = scheduler.compute_beta_schedule((1e-4, 0.02), self.num_steps, "linear")
        self.assertEqual(betas.shape, (self.num_steps,))
        self.assertTrue(torch.all(betas >= 1e-4))
        self.assertTrue(torch.all(betas <= 0.02))

        # Sub-sampled (tau) schedule: betas, alphas, alpha_cumprod and the two
        # sqrt terms, each with one entry per tau step.
        tau_schedules = scheduler.get_tau_schedule()
        self.assertEqual(len(tau_schedules), 5)
        for index, schedule in enumerate(tau_schedules):
            with self.subTest(schedule_index=index):
                self.assertEqual(schedule.shape, (self.tau_num_steps,))

    def test_forward_unclip(self):
        # Image-shaped input
        x0 = torch.randn(self.batch_size, *self.image_size, device=self.device)
        noise = torch.randn_like(x0)
        time_steps = torch.randint(0, self.num_steps, (self.batch_size,), device=self.device)
        xt = self.forward_diffusion(x0, noise, time_steps)
        self.assertEqual(xt.shape, x0.shape)
        self.assertTrue(torch.all(torch.isfinite(xt)))

        # 2D input (latent embeddings)
        x0_2d = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        noise_2d = torch.randn_like(x0_2d)
        xt_2d = self.forward_diffusion(x0_2d, noise_2d, time_steps)
        self.assertEqual(xt_2d.shape, x0_2d.shape)
        self.assertTrue(torch.all(torch.isfinite(xt_2d)))

        # A time step equal to num_steps is out of range; build it outside
        # assertRaises so only the diffusion call is under test.
        invalid_time_steps = torch.full((self.batch_size,), self.num_steps, device=self.device)
        with self.assertRaises(ValueError):
            self.forward_diffusion(x0, noise, invalid_time_steps)

    def test_reverse_unclip(self):
        # Noise prediction
        xt = torch.randn(self.batch_size, *self.image_size, device=self.device)
        model_prediction = torch.randn_like(xt)
        time_steps = torch.randint(0, self.tau_num_steps, (self.batch_size,), device=self.device)
        prev_time_steps = (time_steps - 1).clamp(min=0)
        xt_prev, x0 = self.reverse_diffusion(xt, model_prediction, time_steps, prev_time_steps)
        self.assertEqual(xt_prev.shape, xt.shape)
        self.assertEqual(x0.shape, xt.shape)
        self.assertTrue(torch.all(torch.isfinite(xt_prev)))
        self.assertTrue(torch.all(torch.isfinite(x0)))

        # x0 prediction
        self.reverse_diffusion.set_prediction_type("x0")
        xt_prev, x0 = self.reverse_diffusion(xt, model_prediction, time_steps, prev_time_steps)
        self.assertEqual(xt_prev.shape, xt.shape)
        self.assertEqual(x0.shape, xt.shape)

        # A time step equal to tau_num_steps is out of range.
        invalid_time_steps = torch.full((self.batch_size,), self.tau_num_steps, device=self.device)
        with self.assertRaises(ValueError):
            self.reverse_diffusion(xt, model_prediction, invalid_time_steps, prev_time_steps)

    def test_clip_encoder(self):
        clip_encoder = self._make_clip_encoder()

        # Image encoding (randint's upper bound is exclusive, so 256 covers 0..255)
        images = [
            Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
            for _ in range(self.batch_size)
        ]
        image_embeddings = clip_encoder(images, data_type="img")
        self.assertEqual(image_embeddings.shape, (self.batch_size, self.clip_embedding_dim))

        # Text encoding
        texts = ["A test prompt"] * self.batch_size
        text_embeddings = clip_encoder(texts, data_type="text")
        self.assertEqual(text_embeddings.shape, (self.batch_size, self.clip_embedding_dim))

        # Pairwise similarity between every image and every text
        similarity = clip_encoder.compute_similarity(image_embeddings, text_embeddings)
        self.assertEqual(similarity.shape, (self.batch_size, self.batch_size))

    def test_unclip_decoder(self):
        decoder = self._make_decoder()

        # Forward pass with guidance/dropout disabled
        image_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        text_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        images = torch.randn(self.batch_size, *self.image_size, device=self.device)
        texts = ["test prompt"] * self.batch_size
        predicted_noise, noise = decoder(image_embeddings, text_embeddings, images, texts,
                                         p_classifier_free=0.0, p_text_drop=0.0)
        self.assertEqual(predicted_noise.shape, images.shape)
        self.assertEqual(noise.shape, images.shape)

        # Classifier-free guidance keeps the embedding shape
        modified_embeddings = decoder._apply_classifier_free_guidance(image_embeddings, p_value=0.05)
        self.assertEqual(modified_embeddings.shape, image_embeddings.shape)

        # Text dropout
        # NOTE(review): expects the whole text conditioning to be dropped (None) at
        # p_value=0.6; confirm this is deterministic in UnClipDecoder._apply_text_dropout.
        dropped_embeddings = decoder._apply_text_dropout(text_embeddings, p_value=0.6)
        self.assertIsNone(dropped_embeddings)

    def test_unclip_transformer_prior(self):
        prior = UnCLIPTransformerPrior(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            clip_text_projection=None,
            clip_image_projection=None,
            transformer_embedding_dim=self.clip_embedding_dim,
        ).to(self.device)

        text_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        noisy_image_embeddings = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        timesteps = torch.randint(0, self.num_steps, (self.batch_size,), device=self.device)
        predicted_embeddings = prior(text_embeddings, noisy_image_embeddings, timesteps)
        self.assertEqual(predicted_embeddings.shape, (self.batch_size, self.clip_embedding_dim))

    def test_clip_context_projection(self):
        num_tokens = 4
        projection = CLIPContextProjection(
            clip_embedding_dim=self.clip_embedding_dim, num_tokens=num_tokens
        ).to(self.device)

        z_i = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        c = projection(z_i)
        self.assertEqual(c.shape, (self.batch_size, num_tokens, self.clip_embedding_dim))

    def test_clip_embedding_projection(self):
        reduced_dim = 320
        projection = CLIPEmbeddingProjection(
            clip_embedding_dim=self.clip_embedding_dim,
            transformer_embedding_dim=reduced_dim,
        ).to(self.device)

        # Forward and inverse transform
        x = torch.randn(self.batch_size, self.clip_embedding_dim, device=self.device)
        x_reduced = projection(x)
        self.assertEqual(x_reduced.shape, (self.batch_size, reduced_dim))
        x_reconstructed = projection.inverse_transform(x_reduced)
        self.assertEqual(x_reconstructed.shape, x.shape)

        # Reconstruction loss
        loss = projection.reconstruction_loss(x)
        self.assertTrue(torch.isfinite(loss))

    def test_upsampler_unclip(self):
        upsampler = UpsamplerUnCLIP(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            in_channels=3,
            out_channels=3,
            model_channels=64,
            num_res_blocks=2,
            low_res_size=64,
            high_res_size=self.high_res_size,
        ).to(self.device)

        x_high = torch.randn(self.batch_size, 3, self.high_res_size, self.high_res_size, device=self.device)
        t = torch.randint(0, self.tau_num_steps, (self.batch_size,), device=self.device)
        x_low = torch.randn(self.batch_size, 3, 64, 64, device=self.device)
        predicted_noise = upsampler(x_high, t, x_low)
        self.assertEqual(predicted_noise.shape, (self.batch_size, 3, self.high_res_size, self.high_res_size))

    def test_train_upsampler_unclip(self):
        import itertools

        upsampler = UpsamplerUnCLIP(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            in_channels=3,
            out_channels=3,
            model_channels=64,
            num_res_blocks=2,
        ).to(self.device)
        # Materialise a finite list of batches: an unbounded batch stream would
        # keep a single epoch running forever.
        train_loader = list(itertools.islice(MockDataLoader(batch_size=self.batch_size), 1))
        optimizer = torch.optim.Adam(upsampler.parameters(), lr=1e-3)
        trainer = TrainUpsamplerUnCLIP(
            upsampler_model=upsampler,
            train_loader=train_loader,
            optimizer=optimizer,
            objective=self.metric,
            max_epochs=1,
            device=self.device,
            use_ddp=False,
            use_autocast=False,
        )

        train_losses, best_val_loss = trainer()
        self.assertEqual(len(train_losses), 1)
        self.assertIsInstance(best_val_loss, float)

    def test_sample_unclip(self):
        import tempfile

        prior = UnCLIPTransformerPrior(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            transformer_embedding_dim=self.clip_embedding_dim,
        ).to(self.device)
        decoder = self._make_decoder()
        clip_encoder = self._make_clip_encoder()
        upsampler = UpsamplerUnCLIP(
            forward_diffusion=self.forward_diffusion,
            reverse_diffusion=self.reverse_diffusion,
            in_channels=3,
            out_channels=3,
            model_channels=64,
        ).to(self.device)

        sample_unclip = SampleUnCLIP(
            prior_model=prior,
            decoder_model=decoder,
            clip_model=clip_encoder,
            low_res_upsampler=upsampler,
            second_upsampler_model=None,
            device=self.device,
            batch_size=self.batch_size,
        )

        # Full pipeline without saving
        prompts = ["A test image"] * self.batch_size
        final_images = sample_unclip(prompts=prompts, save_images=False)
        self.assertEqual(final_images.shape, (self.batch_size, 3, self.high_res_size, self.high_res_size))
        self.assertTrue(torch.all(torch.isfinite(final_images)))

        # Saving: write into a throwaway directory so the test leaves nothing in the CWD.
        with tempfile.TemporaryDirectory() as save_path:
            sample_unclip(prompts=prompts, save_images=True, save_path=save_path)
            self.assertTrue(os.path.exists(os.path.join(save_path, "images_256", "image_1.png")))
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
if __name__ == "__main__":
    # Run every TestCase defined in this module when it is executed as a script.
    unittest.main()
|