TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
File without changes
|
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive Test Suite for DDIM Implementation
|
|
3
|
+
|
|
4
|
+
This test suite validates the core components of the DDIM (Denoising Diffusion Implicit Models)
|
|
5
|
+
implementation including forward/reverse diffusion, variance scheduling, training, and sampling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import numpy as np
|
|
12
|
+
import tempfile
|
|
13
|
+
import os
|
|
14
|
+
import shutil
|
|
15
|
+
from typing import Tuple, List
|
|
16
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
17
|
+
from transformers import BertTokenizer
|
|
18
|
+
from torchdiff.ddim import (
|
|
19
|
+
ForwardDDIM,
|
|
20
|
+
ReverseDDIM,
|
|
21
|
+
VarianceSchedulerDDIM,
|
|
22
|
+
TrainDDIM,
|
|
23
|
+
SampleDDIM
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Mock implementations for testing
|
|
29
|
+
class MockNoisePredictor(nn.Module):
    """Mock noise predictor for testing.

    Holds one 3x3 convolution (padding 1, equal input and output channel
    count) and one linear time-embedding layer. Only the convolution
    contributes to the value returned by ``forward``.
    """

    def __init__(self, in_channels=1, time_embed_dim=32):
        """Create the convolution and the time-embedding layer.

        Args:
            in_channels: Channel count used for both the input and the output
                of the convolution.
            time_embed_dim: Output width of the linear time-embedding layer.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.time_embed = nn.Linear(1, time_embed_dim)

    def forward(self, x, t, y_encoded=None, *args):
        """Return the convolution of ``x``.

        Args:
            x: Input tensor handed to the convolution.
            t: Time-step tensor; it is cast to float, given a trailing axis
                and pushed through the embedding layer, but the embedding is
                not mixed into the result.
            y_encoded: Accepted and ignored.
            *args: Accepted and ignored.

        Returns:
            ``self.conv(x)``.
        """
        # The embedding result is deliberately discarded. The call is kept so
        # that a ``t`` the layer cannot process still raises here.
        self.time_embed(t.float().unsqueeze(1))
        return self.conv(x)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class MockConditionalModel(nn.Module):
    """Mock conditional model for testing.

    Wraps a single embedding table with 1000 rows and looks up only the
    first token id of every sequence.
    """

    def __init__(self, embed_dim=32):
        """Create the embedding table.

        Args:
            embed_dim: Width of each embedding vector.
        """
        super().__init__()
        self.embed = nn.Embedding(1000, embed_dim)

    def forward(self, input_ids, attention_mask=None):
        """Embed the first column of ``input_ids``.

        Args:
            input_ids: Two-dimensional tensor of token ids; only column 0 is
                read.
            attention_mask: Accepted and ignored.

        Returns:
            The embedding of ``input_ids[:, 0]``.
        """
        leading_tokens = input_ids[:, 0]  # simple mock: one token per row
        return self.embed(leading_tokens)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MockMetrics:
    """Mock metrics class for testing.

    Exposes three boolean flags and a ``forward`` method that returns fixed
    metric values regardless of its inputs.
    """

    def __init__(self):
        """Switch the ``fid``, ``metrics`` and ``lpips`` flags on."""
        self.fid = self.metrics = self.lpips = True

    def forward(self, x_real, x_fake):
        """Return constant mock metric values.

        Args:
            x_real: Ignored.
            x_fake: Ignored.

        Returns:
            The tuple ``(fid, mse, psnr, ssim, lpips)``.
        """
        fid, mse, psnr, ssim, lpips = 50.0, 0.1, 25.0, 0.8, 0.05
        return fid, mse, psnr, ssim, lpips
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class TestVarianceSchedulerDDIM:
    """Test cases for VarianceSchedulerDDIM."""

    def test_initialization_valid_params(self):
        """Constructor arguments must be readable back from the scheduler."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=100,
            tau_num_steps=10,
            beta_start=1e-4,
            beta_end=0.02,
        )

        assert scheduler.num_steps == 100
        assert scheduler.tau_num_steps == 10
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02

    def test_initialization_invalid_params(self):
        """Invalid constructor arguments must raise ``ValueError``."""
        # Invalid beta range: start is larger than end.
        with pytest.raises(ValueError):
            VarianceSchedulerDDIM(beta_start=0.02, beta_end=1e-4)

        # Invalid num_steps.
        with pytest.raises(ValueError):
            VarianceSchedulerDDIM(num_steps=0)

    def test_beta_schedule_methods(self):
        """Every beta method must yield 100 betas inside [beta_start, beta_end]."""
        for method in ("linear", "sigmoid", "quadratic", "constant", "inverse_time"):
            scheduler = VarianceSchedulerDDIM(num_steps=100, beta_method=method)
            betas = scheduler.betas

            assert len(betas) == 100
            assert torch.all(betas >= scheduler.beta_start)
            assert torch.all(betas <= scheduler.beta_end)

    def test_trainable_beta(self):
        """With ``trainable_beta=True`` the scheduler exposes a ``beta_raw`` parameter."""
        scheduler = VarianceSchedulerDDIM(num_steps=100, trainable_beta=True)

        # beta_raw must exist and be a learnable parameter.
        assert hasattr(scheduler, 'beta_raw')
        assert isinstance(scheduler.beta_raw, nn.Parameter)

        # The derived betas must stay inside the configured range.
        betas = scheduler.betas
        assert torch.all(betas >= scheduler.beta_start)
        assert torch.all(betas <= scheduler.beta_end)

    def test_tau_schedule(self):
        """The subsampled (tau) schedule must have ``tau_num_steps`` entries."""
        scheduler = VarianceSchedulerDDIM(num_steps=1000, tau_num_steps=100)

        # Unpacking into five names also checks that five sequences come back.
        (
            tau_betas,
            tau_alphas,
            tau_alpha_cumprod,
            _tau_sqrt_alpha_cumprod,
            _tau_sqrt_one_minus_alpha_cumprod,
        ) = scheduler.get_tau_schedule()

        assert len(tau_betas) == 100
        assert len(tau_alphas) == 100
        assert len(tau_alpha_cumprod) == 100
        # Floating-point tensors are compared with a tolerance rather than
        # bit-for-bit, so the check does not depend on how alphas are derived.
        assert torch.allclose(tau_alphas, 1 - tau_betas)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class TestForwardDDIM:
    """Test cases for ForwardDDIM."""

    def test_forward_process(self):
        """The forward process must keep the input shape and change the values."""
        scheduler = VarianceSchedulerDDIM(num_steps=100)
        forward_ddim = ForwardDDIM(scheduler)

        batch_size, channels, height, width = 4, 1, 28, 28
        x0 = torch.randn(batch_size, channels, height, width)
        noise = torch.randn_like(x0)
        time_steps = torch.randint(0, scheduler.num_steps, (batch_size,))

        xt = forward_ddim(x0, noise, time_steps)

        assert xt.shape == x0.shape
        assert not torch.equal(xt, x0)  # Should be different due to noise

    def test_invalid_time_steps(self):
        """Out-of-range time steps must raise ``ValueError``."""
        scheduler = VarianceSchedulerDDIM(num_steps=100)
        forward_ddim = ForwardDDIM(scheduler)

        x0 = torch.randn(2, 1, 28, 28)
        noise = torch.randn_like(x0)
        # Built outside the ``raises`` block so that only the forward call
        # itself can satisfy the expectation.
        time_steps = torch.tensor([100, 150])  # Out of range

        with pytest.raises(ValueError):
            forward_ddim(x0, noise, time_steps)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class TestReverseDDIM:
    """Test cases for ReverseDDIM."""

    def test_reverse_process(self):
        """A deterministic reverse step must return two tensors shaped like its input."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=1000,
            tau_num_steps=100,
            eta=0.0,  # Deterministic
        )
        reverse_ddim = ReverseDDIM(scheduler)

        shape = (4, 1, 28, 28)  # batch, channels, height, width
        xt = torch.randn(*shape)
        predicted_noise = torch.randn_like(xt)
        # Start at 1 so that the previous step (t - 1) is never negative.
        time_steps = torch.randint(1, scheduler.tau_num_steps, (shape[0],))

        xt_prev, x0 = reverse_ddim(xt, predicted_noise, time_steps, time_steps - 1)

        assert xt_prev.shape == xt.shape
        assert x0.shape == xt.shape

    def test_stochastic_sampling(self):
        """With eta > 0 two identical calls must give different results."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=1000,
            tau_num_steps=100,
            eta=0.5,  # Stochastic
        )
        reverse_ddim = ReverseDDIM(scheduler)

        xt = torch.randn(2, 1, 28, 28)
        predicted_noise = torch.randn_like(xt)
        step_args = (torch.tensor([50, 60]), torch.tensor([49, 59]))

        # Same inputs twice: any difference has to come from the sampler.
        first, _ = reverse_ddim(xt, predicted_noise, *step_args)
        second, _ = reverse_ddim(xt, predicted_noise, *step_args)

        assert not torch.allclose(first, second, atol=1e-6)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class TestTrainDDIM:
    """Test cases for TrainDDIM."""

    def setup_training_components(self):
        """Build the data loaders, mock models and DDIM modules used by the tests.

        Returns:
            A dict with the keys 'train_loader', 'val_loader',
            'noise_predictor', 'conditional_model', 'forward_ddim',
            'reverse_ddim', 'optimizer', 'loss_fn' and 'metrics'.
        """
        # Mock data: 100 single-channel 28x28 tensors with float labels.
        images = torch.randn(100, 1, 28, 28)
        labels = torch.randint(0, 10, (100,)).float()
        dataset = TensorDataset(images, labels)
        train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
        val_loader = DataLoader(dataset, batch_size=16, shuffle=False)

        # Models
        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()

        # DDIM components share one scheduler.
        scheduler = VarianceSchedulerDDIM(num_steps=100, tau_num_steps=10)
        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # Training components: a single optimizer over both models.
        trainable = [*noise_predictor.parameters(), *conditional_model.parameters()]
        optimizer = torch.optim.Adam(trainable, lr=1e-3)

        return {
            'train_loader': train_loader,
            'val_loader': val_loader,
            'noise_predictor': noise_predictor,
            'conditional_model': conditional_model,
            'forward_ddim': forward_ddim,
            'reverse_ddim': reverse_ddim,
            'optimizer': optimizer,
            'loss_fn': nn.MSELoss(),
            'metrics': MockMetrics(),
        }

    def _make_trainer(self, components, store_path, **extra):
        """Create a CPU, non-DDP ``TrainDDIM`` from the shared components.

        Args:
            components: Dict produced by ``setup_training_components``.
            store_path: Directory passed to the trainer as ``store_path``.
            **extra: Further keyword arguments forwarded to ``TrainDDIM``.
        """
        return TrainDDIM(
            noise_predictor=components['noise_predictor'],
            forward_diffusion=components['forward_ddim'],
            reverse_diffusion=components['reverse_ddim'],
            data_loader=components['train_loader'],
            optimizer=components['optimizer'],
            objective=components['loss_fn'],
            device='cpu',
            store_path=store_path,
            use_ddp=False,
            **extra,
        )

    def test_trainer_initialization(self):
        """The trainer must keep the configured epoch count and device."""
        components = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(
                components,
                temp_dir,
                val_loader=components['val_loader'],
                max_epochs=2,
                conditional_model=components['conditional_model'],
                metrics_=components['metrics'],
            )

            assert trainer.max_epochs == 2
            assert trainer.device == torch.device('cpu')

    def test_training_loop(self):
        """Calling the trainer must return one loss per epoch and a float."""
        components = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(
                components,
                temp_dir,
                max_epochs=2,
                log_frequency=1,
            )

            # Run training
            train_losses, best_val_loss = trainer()

            assert len(train_losses) == 2
            assert isinstance(best_val_loss, float)

    def test_checkpoint_save_load(self):
        """A saved checkpoint must load back with the same epoch and loss."""
        components = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(components, temp_dir, max_epochs=1)

            # Save checkpoint manually
            trainer._save_checkpoint(1, 0.5)

            # At least one .pth file must now exist in the store directory.
            checkpoint_files = [
                name for name in os.listdir(temp_dir) if name.endswith('.pth')
            ]
            assert len(checkpoint_files) > 0

            # Load checkpoint
            checkpoint_path = os.path.join(temp_dir, checkpoint_files[0])
            epoch, loss = trainer.load_checkpoint(checkpoint_path)

            assert epoch == 1
            assert loss == 0.5
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
class TestSampleDDIM:
    """Test cases for SampleDDIM."""

    def setup_sampling_components(self):
        """Build the reverse process and mock models used by the sampling tests.

        Returns:
            A dict with the keys 'reverse_ddim', 'noise_predictor' and
            'conditional_model'.
        """
        scheduler = VarianceSchedulerDDIM(num_steps=100, tau_num_steps=10)
        reverse_ddim = ReverseDDIM(scheduler)
        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()

        return {
            'reverse_ddim': reverse_ddim,
            'noise_predictor': noise_predictor,
            'conditional_model': conditional_model,
        }

    def _make_sampler(self, components, conditional=False, **options):
        """Create a CPU ``SampleDDIM`` for 28x28 images.

        Args:
            components: Dict produced by ``setup_sampling_components``.
            conditional: When true, the mock conditional model is passed too.
            **options: Further keyword arguments forwarded to ``SampleDDIM``.
        """
        if conditional:
            options['conditional_model'] = components['conditional_model']
        return SampleDDIM(
            reverse_diffusion=components['reverse_ddim'],
            noise_predictor=components['noise_predictor'],
            image_shape=(28, 28),
            device='cpu',
            **options,
        )

    def test_sampler_initialization(self):
        """The sampler must keep the configured image shape, batch size and device."""
        components = self.setup_sampling_components()

        sampler = self._make_sampler(components, batch_size=4)

        assert sampler.image_shape == (28, 28)
        assert sampler.batch_size == 4
        assert sampler.device == torch.device('cpu')

    def test_unconditional_sampling(self):
        """Sampling without conditions must give a (2, 1, 28, 28) batch within [0, 1]."""
        components = self.setup_sampling_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(
                components,
                batch_size=2,
                in_channels=1,  # Match the mock predictor
                image_output_range=(-1, 1),
            )

            # Generate samples
            generated_images = sampler(
                conditions=None,
                save_images=True,
                save_path=temp_dir,
            )

            assert generated_images.shape == (2, 1, 28, 28)  # 1 channel to match predictor
            assert torch.all(generated_images >= 0)
            assert torch.all(generated_images <= 1)

    def test_conditional_sampling(self):
        """Sampling from string prompts must give a (2, 1, 28, 28) batch."""
        components = self.setup_sampling_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(
                components,
                conditional=True,
                batch_size=2,
                in_channels=1,  # Match the mock predictor
            )

            # Test with string prompts
            generated_images = sampler(
                conditions=["cat", "dog"],
                save_images=False,
            )

            assert generated_images.shape == (2, 1, 28, 28)  # 1 channel to match predictor

    def test_tokenization(self):
        """Tokenizing two prompts must give two rows of ``max_token_length`` ids."""
        components = self.setup_sampling_components()

        sampler = self._make_sampler(components, conditional=True)

        input_ids, attention_mask = sampler.tokenize(["a cat", "a dog"])

        assert input_ids.shape[0] == 2  # Batch size
        assert attention_mask.shape[0] == 2
        assert input_ids.shape[1] == sampler.max_token_length
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class TestIntegrationDDIM:
    """Integration tests for full DDIM pipeline."""

    def test_end_to_end_pipeline(self):
        """Train for one epoch, then sample two conditioned images."""
        # Setup data (small tensors for faster testing)
        x = torch.randn(20, 1, 16, 16)
        y = torch.randint(0, 5, (20,)).float()
        dataset = TensorDataset(x, y)
        train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Setup models
        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()

        # Setup DDIM components
        scheduler = VarianceSchedulerDDIM(num_steps=50, tau_num_steps=5)
        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # Setup training: one optimizer over both models
        optimizer = torch.optim.Adam(
            [*noise_predictor.parameters(), *conditional_model.parameters()],
            lr=1e-3,
        )
        loss_fn = nn.MSELoss()

        with tempfile.TemporaryDirectory() as temp_dir:
            # Train model
            trainer = TrainDDIM(
                noise_predictor=noise_predictor,
                forward_diffusion=forward_ddim,
                reverse_diffusion=reverse_ddim,
                data_loader=train_loader,
                optimizer=optimizer,
                objective=loss_fn,
                max_epochs=1,
                device='cpu',
                conditional_model=conditional_model,
                store_path=temp_dir,
                use_ddp=False,
            )

            train_losses, _ = trainer()
            assert len(train_losses) == 1

            # Test sampling with the models that were just trained
            sampler = SampleDDIM(
                reverse_diffusion=reverse_ddim,
                noise_predictor=noise_predictor,
                conditional_model=conditional_model,
                image_shape=(16, 16),
                batch_size=2,
                in_channels=1,
                device='cpu',
            )

            conditions = ["test1", "test2"]
            generated_images = sampler(conditions=conditions, save_images=False)

            assert generated_images.shape == (2, 1, 16, 16)

    def test_forward_reverse_consistency(self):
        """Check that a forward step followed by a reverse step keeps the shape.

        Only tensor shapes are asserted; the numerical values of the
        reconstruction are not compared with the original image.
        """
        scheduler = VarianceSchedulerDDIM(
            num_steps=100,
            tau_num_steps=10,
            eta=0.0,  # Deterministic
        )

        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # Original image
        x0 = torch.randn(1, 1, 16, 16)

        # Apply forward process using full timestep range
        noise = torch.randn_like(x0)
        t_full = torch.tensor([50])  # Use full timestep range for forward
        xt = forward_ddim(x0, noise, t_full)

        # Apply reverse process using tau (subsampled) timestep range. The
        # noise that was added is handed in as the "predicted" noise.
        t_tau = torch.tensor([5])  # Middle of tau range (0-9)
        prev_t_tau = torch.tensor([4])  # Previous tau timestep
        xt_prev, x0_pred = reverse_ddim(xt, noise, t_tau, prev_t_tau)

        # Check that shapes are consistent
        assert xt_prev.shape == x0.shape
        assert x0_pred.shape == x0.shape
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
# Pytest configuration and runner
|
|
533
|
+
if __name__ == "__main__":
    # Run the test categories one after another, announcing each of them.
    # pytest treats a bare positional argument as a file or directory, so
    # each class is addressed with a "<this file>::<class>" node id.
    test_categories = [
        ("VarianceSchedulerDDIM", "TestVarianceSchedulerDDIM"),
        ("ForwardDDIM", "TestForwardDDIM"),
        ("ReverseDDIM", "TestReverseDDIM"),
        ("TrainDDIM", "TestTrainDDIM"),
        ("SampleDDIM", "TestSampleDDIM"),
        ("Integration", "TestIntegrationDDIM"),
    ]

    for position, (label, class_name) in enumerate(test_categories):
        # Every announcement except the first starts on a fresh line.
        separator = "" if position == 0 else "\n"
        print(f"{separator}Running {label} tests...")
        pytest.main(["-v", f"{__file__}::{class_name}"])
|