TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
@@ -0,0 +1,1188 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive test suite for the DDPM (Denoising Diffusion Probabilistic Models) implementation.
|
|
3
|
+
|
|
4
|
+
This test suite covers all major components of the DDPM implementation including:
|
|
5
|
+
- VarianceSchedulerDDPM (noise scheduling)
|
|
6
|
+
- ForwardDDPM (forward diffusion process)
|
|
7
|
+
- ReverseDDPM (reverse diffusion process)
|
|
8
|
+
- TrainDDPM (training loop)
|
|
9
|
+
- SampleDDPM (image generation)
|
|
10
|
+
- Integration tests with different configurations
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
python test_ddpm.py
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import unittest
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn as nn
|
|
19
|
+
import tempfile
|
|
20
|
+
import shutil
|
|
21
|
+
import os
|
|
22
|
+
from unittest.mock import Mock, patch
|
|
23
|
+
import numpy as np
|
|
24
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
25
|
+
import warnings
|
|
26
|
+
from torchdiff.ddpm import VarianceSchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Simple mock classes for testing
|
|
31
|
+
class MockNoisePredictor(nn.Module):
    """Tiny convolutional stand-in for a real noise predictor.

    Each sample's scalar timestep is projected to 16 features by a linear
    layer, and those features are added (constant over every pixel) to the
    output of the first convolution. Conditioning inputs are accepted only
    for interface compatibility and are ignored.

    Args:
        in_channels: Channel count of the input images; the output has the
            same number of channels.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # Creation order (conv1, conv2, time_mlp) is significant: it fixes how
        # the global RNG is consumed during seeded parameter initialisation.
        self.conv1 = nn.Conv2d(in_channels, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, in_channels, 3, padding=1)
        self.time_mlp = nn.Linear(1, 16)

    def forward(self, x, t, y_encoded=None, context=None):
        """Return a tensor with the same shape as ``x``.

        ``y_encoded`` and ``context`` are not used.
        """
        height, width = x.shape[2], x.shape[3]

        time_features = self.time_mlp(t.float().unsqueeze(-1))
        time_map = time_features.view(
            time_features.shape[0], time_features.shape[1], 1, 1
        ).expand(-1, -1, height, width)

        hidden = torch.relu(self.conv1(x) + time_map)
        return self.conv2(hidden)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class MockConditionalModel(nn.Module):
    """Stand-in text encoder that emits random embeddings.

    Args:
        embed_dim: Width of every returned embedding vector.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, input_ids, attention_mask):
        """Return a ``(batch, embed_dim)`` tensor of standard-normal samples.

        Only the leading dimension of ``input_ids`` is read; the token values
        and ``attention_mask`` are ignored.
        """
        return torch.randn(input_ids.shape[0], self.embed_dim)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class MockMetrics:
    """Stand-in for the ``metrics_`` object handed to ``TrainDDPM`` in these tests.

    Every evaluation returns the same fixed scores, so validation code paths
    can be exercised deterministically without computing real metrics.
    """

    def __init__(self):
        # Flags the trainer may consult to decide which metrics to compute
        # (assumed usage — confirm against TrainDDPM); all are enabled.
        self.fid = True
        self.metrics = True
        self.lpips = True

    def forward(self, real_imgs, fake_imgs):
        """Return fixed ``(fid, mse, psnr, ssim, lpips)`` scores; inputs are ignored."""
        return 50.0, 0.1, 25.0, 0.8, 0.05

    def __call__(self, real_imgs, fake_imgs):
        # Support direct invocation, ``metrics(real, fake)`` (the nn.Module
        # calling convention), in addition to an explicit ``.forward(...)``;
        # without this, callers using the call syntax raise TypeError.
        return self.forward(real_imgs, fake_imgs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class TestVarianceSchedulerDDPM(unittest.TestCase):
    """Test cases for VarianceSchedulerDDPM"""

    def setUp(self):
        self.num_steps = 100
        self.beta_start = 1e-4
        self.beta_end = 0.02

    def test_initialization_fixed_schedule(self):
        """Test initialization with fixed (non-trainable) schedule"""
        scheduler = VarianceSchedulerDDPM(
            num_steps=self.num_steps,
            beta_start=self.beta_start,
            beta_end=self.beta_end,
            trainable_beta=False,
        )

        self.assertEqual(scheduler.num_steps, self.num_steps)
        self.assertEqual(scheduler.beta_start, self.beta_start)
        self.assertEqual(scheduler.beta_end, self.beta_end)
        self.assertFalse(scheduler.trainable_beta)

        # The schedule tensors must be exposed as attributes, one beta per step.
        for name in ("betas", "alphas", "alpha_bars"):
            self.assertTrue(hasattr(scheduler, name), name)
        self.assertEqual(scheduler.betas.shape[0], self.num_steps)

    def test_initialization_trainable_schedule(self):
        """Test initialization with trainable schedule"""
        scheduler = VarianceSchedulerDDPM(
            num_steps=self.num_steps,
            trainable_beta=True,
        )

        self.assertTrue(scheduler.trainable_beta)
        self.assertIsInstance(scheduler.betas, nn.Parameter)

    def test_invalid_parameters(self):
        """Each invalid configuration must raise ValueError on its own."""
        invalid_configs = [
            {"beta_start": 0.02, "beta_end": 1e-4},  # start > end
            {"num_steps": -1},  # negative steps
            {"beta_start": 0.0},  # zero start
        ]
        for kwargs in invalid_configs:
            # subTest reports exactly which configuration failed to raise.
            with self.subTest(**kwargs):
                with self.assertRaises(ValueError):
                    VarianceSchedulerDDPM(**kwargs)

    def test_beta_schedule_methods(self):
        """Every schedule method yields num_steps betas within [beta_start, beta_end]."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            with self.subTest(beta_method=method):
                # Pass the bounds explicitly: the range assertions below are
                # only meaningful against the bounds the scheduler was built
                # with, not against whatever its defaults happen to be.
                scheduler = VarianceSchedulerDDPM(
                    num_steps=self.num_steps,
                    beta_start=self.beta_start,
                    beta_end=self.beta_end,
                    beta_method=method,
                )
                betas = scheduler.betas
                self.assertEqual(betas.shape[0], self.num_steps)
                self.assertTrue(torch.all(betas >= self.beta_start))
                self.assertTrue(torch.all(betas <= self.beta_end))

    def test_invalid_beta_method(self):
        """Test invalid beta schedule method"""
        with self.assertRaises(ValueError):
            VarianceSchedulerDDPM(beta_method="invalid_method")

    def test_compute_schedule_fixed(self):
        """Test compute_schedule method with fixed schedule"""
        scheduler = VarianceSchedulerDDPM(num_steps=self.num_steps, trainable_beta=False)

        # Without time_steps: one entry per diffusion step is expected.
        betas, alphas, alpha_bars, _, _ = scheduler.compute_schedule()

        self.assertEqual(betas.shape[0], self.num_steps)
        self.assertEqual(alphas.shape[0], self.num_steps)
        self.assertEqual(alpha_bars.shape[0], self.num_steps)

        # With explicit time_steps: one entry per requested step is expected.
        time_steps = torch.tensor([0, 50, 99])
        betas_t, alphas_t, _, _, _ = scheduler.compute_schedule(time_steps)

        self.assertEqual(betas_t.shape[0], 3)
        self.assertEqual(alphas_t.shape[0], 3)

    def test_compute_schedule_trainable(self):
        """Test compute_schedule method with trainable schedule"""
        scheduler = VarianceSchedulerDDPM(num_steps=self.num_steps, trainable_beta=True)

        time_steps = torch.tensor([0, 50, 99])
        betas_t, _, _, _, _ = scheduler.compute_schedule(time_steps)

        self.assertEqual(betas_t.shape[0], 3)
        # Learned betas must stay valid variances in the open interval (0, 1).
        self.assertTrue(torch.all(betas_t > 0))
        self.assertTrue(torch.all(betas_t < 1))
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class TestForwardDDPM(unittest.TestCase):
    """Unit tests for ForwardDDPM (the forward diffusion / noising process)."""

    def setUp(self):
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        self.forward_ddpm = ForwardDDPM(self.variance_scheduler)
        self.batch_size = 4
        self.channels = 1
        self.height = 28
        self.width = 28

    def _clean_and_noise(self):
        """Draw a random image batch and Gaussian noise of the same shape."""
        clean = torch.randn(self.batch_size, self.channels, self.height, self.width)
        return clean, torch.randn_like(clean)

    def test_forward_process(self):
        """Noised output keeps the input shape and contains only finite values."""
        x0, noise = self._clean_and_noise()
        steps = torch.randint(0, self.variance_scheduler.num_steps, (self.batch_size,))

        noisy = self.forward_ddpm(x0, noise, steps)

        self.assertEqual(noisy.shape, x0.shape)
        self.assertTrue(torch.all(torch.isfinite(noisy)))

    def test_forward_process_trainable(self):
        """The forward process also runs on top of a trainable variance schedule."""
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=True)
        diffusion = ForwardDDPM(scheduler)

        x0, noise = self._clean_and_noise()
        steps = torch.randint(0, scheduler.num_steps, (self.batch_size,))

        self.assertEqual(diffusion(x0, noise, steps).shape, x0.shape)

    def test_invalid_time_steps(self):
        """A batch containing out-of-range time steps raises ValueError."""
        x0, noise = self._clean_and_noise()
        out_of_range = torch.tensor([100, 200, -1, 50])

        with self.assertRaises(ValueError):
            self.forward_ddpm(x0, noise, out_of_range)

    def test_edge_cases(self):
        """The first (t=0) and last (t=num_steps-1) steps are accepted."""
        x0, noise = self._clean_and_noise()
        last_step = self.variance_scheduler.num_steps - 1

        # Only the output shape is checked at each boundary step.
        for step in (0, last_step):
            steps = torch.full((self.batch_size,), step, dtype=torch.long)
            self.assertEqual(self.forward_ddpm(x0, noise, steps).shape, x0.shape)
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
class TestReverseDDPM(unittest.TestCase):
    """Unit tests for ReverseDDPM (the reverse diffusion / denoising step)."""

    def setUp(self):
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
        self.batch_size = 4
        self.channels = 1
        self.height = 28
        self.width = 28

    def _noisy_and_prediction(self):
        """Draw a random noisy batch and a random predicted-noise tensor of equal shape."""
        noisy = torch.randn(self.batch_size, self.channels, self.height, self.width)
        return noisy, torch.randn_like(noisy)

    def test_reverse_process(self):
        """One reverse step keeps the shape and produces only finite values."""
        xt, eps = self._noisy_and_prediction()
        steps = torch.randint(1, self.variance_scheduler.num_steps, (self.batch_size,))

        previous = self.reverse_ddpm(xt, eps, steps)

        self.assertEqual(previous.shape, xt.shape)
        self.assertTrue(torch.all(torch.isfinite(previous)))

    def test_reverse_process_t_zero(self):
        """The reverse step accepts t=0 for the whole batch."""
        xt, eps = self._noisy_and_prediction()
        steps = torch.zeros(self.batch_size, dtype=torch.long)

        # Only the output shape is asserted; determinism at t=0 is not checked.
        self.assertEqual(self.reverse_ddpm(xt, eps, steps).shape, xt.shape)

    def test_reverse_process_trainable(self):
        """The reverse step also runs on top of a trainable variance schedule."""
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=True)
        reverse = ReverseDDPM(scheduler)

        xt, eps = self._noisy_and_prediction()
        steps = torch.randint(1, scheduler.num_steps, (self.batch_size,))

        self.assertEqual(reverse(xt, eps, steps).shape, xt.shape)

    def test_invalid_time_steps(self):
        """A batch containing out-of-range time steps raises ValueError."""
        xt, eps = self._noisy_and_prediction()
        out_of_range = torch.tensor([100, 200, -1, 50])

        with self.assertRaises(ValueError):
            self.reverse_ddpm(xt, eps, out_of_range)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
class TestTrainDDPM(unittest.TestCase):
    """Test cases for TrainDDPM"""

    def setUp(self):
        # Temporary directory used as the trainer's store_path.
        self.test_dir = tempfile.mkdtemp()

        self.batch_size = 8
        self.channels = 1
        self.height = 28
        self.width = 28

        # 32 random images paired with integer placeholder labels.
        x_data = torch.randn(32, self.channels, self.height, self.width)
        dataset = TensorDataset(x_data, torch.arange(32))

        self.train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Diffusion components and optimisation setup shared by all tests.
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
        self.forward_ddpm = ForwardDDPM(self.variance_scheduler)
        self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
        self.noise_predictor = MockNoisePredictor(self.channels)
        self.optimizer = torch.optim.Adam(self.noise_predictor.parameters(), lr=1e-3)
        self.objective = nn.MSELoss()

    def tearDown(self):
        # Clean up temporary directory
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _make_trainer(self, **overrides):
        """Build a TrainDDPM wired to the fixtures created in setUp.

        The shared keyword arguments below are always passed; ``overrides``
        replaces or extends them, so each test states only what it varies.
        """
        kwargs = dict(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_ddpm,
            reverse_diffusion=self.reverse_ddpm,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            device='cpu',
            store_path=self.test_dir,
        )
        kwargs.update(overrides)
        return TrainDDPM(**kwargs)

    def test_initialization(self):
        """Test TrainDDPM initialization"""
        trainer = self._make_trainer(max_epochs=2)

        self.assertEqual(trainer.max_epochs, 2)
        self.assertEqual(trainer.device, torch.device('cpu'))
        self.assertIsNotNone(trainer.scheduler)
        self.assertIsNotNone(trainer.warmup_lr_scheduler)

    def test_training_loop_basic(self):
        """Test basic training loop without validation"""
        trainer = self._make_trainer(max_epochs=2, log_frequency=1)

        train_losses, best_val_loss = trainer()

        self.assertEqual(len(train_losses), 2)  # one entry per epoch
        self.assertTrue(all(isinstance(loss, float) for loss in train_losses))
        self.assertIsInstance(best_val_loss, float)

    def test_training_with_validation(self):
        """Test training loop with validation"""
        trainer = self._make_trainer(
            val_loader=self.val_loader,
            metrics_=MockMetrics(),
            max_epochs=2,
            val_frequency=1,
            log_frequency=1,
        )

        train_losses, best_val_loss = trainer()

        self.assertEqual(len(train_losses), 2)
        self.assertIsInstance(best_val_loss, float)

    def test_training_with_conditional_model(self):
        """Test training with conditional model"""
        conditional_model = MockConditionalModel()

        # Build a loader whose labels are string prompts for the conditional model.
        x_data = torch.randn(16, self.channels, self.height, self.width)
        y_data = ['test prompt'] * 16

        class StringDataset:
            """Minimal map-style dataset yielding (image, prompt) pairs."""

            def __init__(self, x_data, y_data):
                self.x_data = x_data
                self.y_data = y_data

            def __len__(self):
                return len(self.x_data)

            def __getitem__(self, idx):
                return self.x_data[idx], self.y_data[idx]

        string_loader = DataLoader(StringDataset(x_data, y_data), batch_size=4, shuffle=True)

        trainer = self._make_trainer(
            data_loader=string_loader,
            conditional_model=conditional_model,
            max_epochs=1,
            log_frequency=1,
        )

        train_losses, _ = trainer()
        self.assertEqual(len(train_losses), 1)

    def test_gradient_accumulation(self):
        """Test training with gradient accumulation"""
        trainer = self._make_trainer(
            max_epochs=1,
            grad_accumulation_steps=2,
            log_frequency=1,
        )

        train_losses, _ = trainer()
        self.assertEqual(len(train_losses), 1)

    def test_warmup_scheduler(self):
        """Test warmup scheduler functionality"""
        optimizer = torch.optim.Adam(self.noise_predictor.parameters(), lr=1e-3)
        warmup_epochs = 5

        scheduler = TrainDDPM.warmup_scheduler(optimizer, warmup_epochs)

        # ``lr_lambdas`` is the LambdaLR attribute holding one multiplier
        # function per param group (assumes warmup_scheduler returns LambdaLR).
        lr_lambda = scheduler.lr_lambdas[0]

        # Warmup phase: the multiplier rises linearly from 0 towards 1.
        for epoch in range(warmup_epochs):
            self.assertAlmostEqual(lr_lambda(epoch), epoch / warmup_epochs, places=5)

        # Post-warmup phase: the multiplier stays at 1.0.
        self.assertEqual(lr_lambda(warmup_epochs), 1.0)
        self.assertEqual(lr_lambda(warmup_epochs + 1), 1.0)
        self.assertEqual(lr_lambda(warmup_epochs + 10), 1.0)  # Much later

        # Edge cases
        self.assertEqual(lr_lambda(0), 0.0)  # At start, lr multiplier should be 0
        self.assertAlmostEqual(
            lr_lambda(warmup_epochs - 1), (warmup_epochs - 1) / warmup_epochs, places=5
        )

        # A different warmup length rescales the ramp.
        lr_lambda2 = TrainDDPM.warmup_scheduler(optimizer, 10).lr_lambdas[0]

        self.assertEqual(lr_lambda2(0), 0.0)
        self.assertAlmostEqual(lr_lambda2(5), 0.5, places=5)
        self.assertEqual(lr_lambda2(10), 1.0)

    @patch('torch.save')
    def test_checkpoint_saving(self, mock_save):
        """Test checkpoint saving functionality"""
        trainer = self._make_trainer(max_epochs=1, val_frequency=1)

        # Exercise the private save helper directly; torch.save is patched out,
        # so nothing is written to disk.
        trainer._save_checkpoint(1, 0.5)

        self.assertTrue(mock_save.called)

    def test_checkpoint_loading(self):
        """Test checkpoint loading functionality"""
        trainer = self._make_trainer(max_epochs=1)

        # Hand-built checkpoint using the keys load_checkpoint is expected to read.
        checkpoint = {
            'epoch': 5,
            'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
            'model_state_dict_conditional': None,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': 0.5,
            'variance_scheduler_model': self.variance_scheduler.state_dict(),
            'max_epochs': 10,
        }

        checkpoint_path = os.path.join(self.test_dir, 'test_checkpoint.pth')
        torch.save(checkpoint, checkpoint_path)

        epoch, loss = trainer.load_checkpoint(checkpoint_path)

        self.assertEqual(epoch, 5)
        self.assertEqual(loss, 0.5)

    def test_validation_method(self):
        """Test validation method"""
        trainer = self._make_trainer(
            val_loader=self.val_loader,
            metrics_=MockMetrics(),
            max_epochs=1,
        )

        val_metrics = trainer.validate()

        self.assertEqual(len(val_metrics), 6)  # val_loss, fid, mse, psnr, ssim, lpips
        val_loss, _fid, _mse, _psnr, _ssim, _lpips_score = val_metrics
        self.assertIsInstance(val_loss, float)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
class TestSampleDDPM(unittest.TestCase):
|
|
574
|
+
"""Test cases for SampleDDPM"""
|
|
575
|
+
|
|
576
|
+
def setUp(self):
|
|
577
|
+
self.test_dir = tempfile.mkdtemp()
|
|
578
|
+
|
|
579
|
+
self.variance_scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
|
|
580
|
+
self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
|
|
581
|
+
self.noise_predictor = MockNoisePredictor(in_channels=3)
|
|
582
|
+
self.conditional_model = MockConditionalModel()
|
|
583
|
+
|
|
584
|
+
self.image_shape = (32, 32)
|
|
585
|
+
self.batch_size = 4
|
|
586
|
+
|
|
587
|
+
def tearDown(self):
|
|
588
|
+
shutil.rmtree(self.test_dir, ignore_errors=True)
|
|
589
|
+
|
|
590
|
+
def test_initialization(self):
|
|
591
|
+
"""Test SampleDDPM initialization"""
|
|
592
|
+
sampler = SampleDDPM(
|
|
593
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
594
|
+
noise_predictor=self.noise_predictor,
|
|
595
|
+
image_shape=self.image_shape,
|
|
596
|
+
batch_size=self.batch_size,
|
|
597
|
+
device='cpu'
|
|
598
|
+
)
|
|
599
|
+
|
|
600
|
+
self.assertEqual(sampler.image_shape, self.image_shape)
|
|
601
|
+
self.assertEqual(sampler.batch_size, self.batch_size)
|
|
602
|
+
self.assertEqual(sampler.device, torch.device('cpu'))
|
|
603
|
+
|
|
604
|
+
def test_invalid_initialization(self):
|
|
605
|
+
"""Test initialization with invalid parameters"""
|
|
606
|
+
with self.assertRaises(ValueError):
|
|
607
|
+
SampleDDPM(
|
|
608
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
609
|
+
noise_predictor=self.noise_predictor,
|
|
610
|
+
image_shape=(32,), # Invalid shape
|
|
611
|
+
batch_size=self.batch_size
|
|
612
|
+
)
|
|
613
|
+
|
|
614
|
+
with self.assertRaises(ValueError):
|
|
615
|
+
SampleDDPM(
|
|
616
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
617
|
+
noise_predictor=self.noise_predictor,
|
|
618
|
+
image_shape=self.image_shape,
|
|
619
|
+
batch_size=0 # Invalid batch size
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
def test_tokenization(self):
|
|
623
|
+
"""Test text tokenization"""
|
|
624
|
+
sampler = SampleDDPM(
|
|
625
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
626
|
+
noise_predictor=self.noise_predictor,
|
|
627
|
+
image_shape=self.image_shape,
|
|
628
|
+
conditional_model=self.conditional_model,
|
|
629
|
+
batch_size=self.batch_size,
|
|
630
|
+
device='cpu'
|
|
631
|
+
)
|
|
632
|
+
|
|
633
|
+
# Test single prompt
|
|
634
|
+
prompts = "a beautiful landscape"
|
|
635
|
+
input_ids, attention_mask = sampler.tokenize(prompts)
|
|
636
|
+
|
|
637
|
+
self.assertEqual(input_ids.shape[0], 1)
|
|
638
|
+
self.assertEqual(attention_mask.shape[0], 1)
|
|
639
|
+
|
|
640
|
+
# Test multiple prompts
|
|
641
|
+
prompts = ["prompt 1", "prompt 2", "prompt 3"]
|
|
642
|
+
input_ids, attention_mask = sampler.tokenize(prompts)
|
|
643
|
+
|
|
644
|
+
self.assertEqual(input_ids.shape[0], 3)
|
|
645
|
+
self.assertEqual(attention_mask.shape[0], 3)
|
|
646
|
+
|
|
647
|
+
def test_unconditional_generation(self):
|
|
648
|
+
"""Test unconditional image generation"""
|
|
649
|
+
sampler = SampleDDPM(
|
|
650
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
651
|
+
noise_predictor=self.noise_predictor,
|
|
652
|
+
image_shape=self.image_shape,
|
|
653
|
+
batch_size=self.batch_size,
|
|
654
|
+
device='cpu'
|
|
655
|
+
)
|
|
656
|
+
|
|
657
|
+
generated_imgs = sampler(
|
|
658
|
+
conditions=None,
|
|
659
|
+
normalize_output=True,
|
|
660
|
+
save_images=False
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
expected_shape = (self.batch_size, 3, self.image_shape[0], self.image_shape[1])
|
|
664
|
+
self.assertEqual(generated_imgs.shape, expected_shape)
|
|
665
|
+
|
|
666
|
+
# Check that images are normalized to [0, 1]
|
|
667
|
+
self.assertTrue(torch.all(generated_imgs >= 0))
|
|
668
|
+
self.assertTrue(torch.all(generated_imgs <= 1))
|
|
669
|
+
|
|
670
|
+
def test_conditional_generation(self):
|
|
671
|
+
"""Test conditional image generation"""
|
|
672
|
+
sampler = SampleDDPM(
|
|
673
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
674
|
+
noise_predictor=self.noise_predictor,
|
|
675
|
+
image_shape=self.image_shape,
|
|
676
|
+
conditional_model=self.conditional_model,
|
|
677
|
+
batch_size=self.batch_size,
|
|
678
|
+
device='cpu'
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
prompts = ["a beautiful sunset", "a mountain landscape", "a city skyline", "a forest path"]
|
|
682
|
+
|
|
683
|
+
generated_imgs = sampler(
|
|
684
|
+
conditions=prompts,
|
|
685
|
+
normalize_output=True,
|
|
686
|
+
save_images=False
|
|
687
|
+
)
|
|
688
|
+
|
|
689
|
+
expected_shape = (self.batch_size, 3, self.image_shape[0], self.image_shape[1])
|
|
690
|
+
self.assertEqual(generated_imgs.shape, expected_shape)
|
|
691
|
+
|
|
692
|
+
def test_image_saving(self):
|
|
693
|
+
"""Test image saving functionality"""
|
|
694
|
+
sampler = SampleDDPM(
|
|
695
|
+
reverse_diffusion=self.reverse_ddpm,
|
|
696
|
+
noise_predictor=self.noise_predictor,
|
|
697
|
+
image_shape=self.image_shape,
|
|
698
|
+
batch_size=2, # Use smaller batch for faster test
|
|
699
|
+
device='cpu'
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
generated_imgs = sampler(
|
|
703
|
+
conditions=None,
|
|
704
|
+
save_images=True,
|
|
705
|
+
save_path=self.test_dir
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
# Check that images were saved
|
|
709
|
+
saved_files = os.listdir(self.test_dir)
|
|
710
|
+
self.assertEqual(len(saved_files), 2) # 2 images
|
|
711
|
+
self.assertTrue(all(f.endswith('.png') for f in saved_files))
|
|
712
|
+
|
|
713
|
+
def test_device_transfer(self):
    """Calling ``.to(cpu)`` on a CPU sampler must report the CPU device."""
    cpu = torch.device('cpu')
    sampler = SampleDDPM(
        reverse_diffusion=self.reverse_ddpm,
        noise_predictor=self.noise_predictor,
        image_shape=self.image_shape,
        batch_size=self.batch_size,
        device='cpu',
    )
    moved = sampler.to(cpu)
    self.assertEqual(moved.device, cpu)
|
|
726
|
+
|
|
727
|
+
def test_error_conditions(self):
    """Mismatched conditioning must raise ValueError.

    Two cases are checked:
      * prompts passed to a sampler built without a conditional model;
      * no prompts passed to a sampler built with one.
    """
    shared = dict(
        reverse_diffusion=self.reverse_ddpm,
        noise_predictor=self.noise_predictor,
        image_shape=self.image_shape,
        batch_size=self.batch_size,
        device='cpu',
    )

    # Prompts supplied, but there is no model to encode them.
    unconditional = SampleDDPM(**shared)
    with self.assertRaises(ValueError):
        unconditional(conditions=["test prompt"])

    # A conditional model is present, but no prompts are supplied.
    conditional = SampleDDPM(conditional_model=self.conditional_model, **shared)
    with self.assertRaises(ValueError):
        conditional(conditions=None)
|
|
753
|
+
|
|
754
|
+
class TestIntegration(unittest.TestCase):
    """Integration tests for the complete DDPM pipeline.

    Each test wires VarianceSchedulerDDPM / ForwardDDPM / ReverseDDPM /
    TrainDDPM (and SampleDDPM where sampling is tested) to the mock noise
    predictor (and, for the conditional case, the mock conditional model).
    It trains for a single epoch on tiny random data on the CPU and checks
    the shapes and counts that come back.
    """

    def setUp(self):
        # Seed torch's global RNG so the random training data, DataLoader
        # shuffling and sampling noise are identical on every run. The
        # previous CPU RNG state is restored afterwards (addCleanup also runs
        # if setUp fails later on), so the fixed seed does not leak into
        # other test classes.
        self.addCleanup(torch.set_rng_state, torch.get_rng_state())
        torch.manual_seed(0)

        self.test_dir = tempfile.mkdtemp()

        # Tiny single-channel dataset: 16 random tensors with integer labels.
        self.batch_size = 4
        self.channels = 1
        self.height = 16  # Smaller for faster tests
        self.width = 16

        x_data = torch.randn(16, self.channels, self.height, self.width)
        y_data = torch.arange(16)
        dataset = TensorDataset(x_data, y_data)

        self.data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_complete_pipeline_unconditional(self):
        """Test complete pipeline: training -> sampling (unconditional)"""
        # Setup components
        variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False)  # Small for speed
        forward_ddpm = ForwardDDPM(variance_scheduler)
        reverse_ddpm = ReverseDDPM(variance_scheduler)
        noise_predictor = MockNoisePredictor(self.channels)
        optimizer = torch.optim.Adam(noise_predictor.parameters(), lr=1e-3)
        objective = nn.MSELoss()

        # Training
        trainer = TrainDDPM(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_ddpm,
            reverse_diffusion=reverse_ddpm,
            data_loader=self.data_loader,
            optimizer=optimizer,
            objective=objective,
            max_epochs=1,  # Just 1 epoch for speed
            device='cpu',
            store_path=self.test_dir,
            log_frequency=1,
        )
        train_losses, _ = trainer()

        # Sampling
        sampler = SampleDDPM(
            reverse_diffusion=reverse_ddpm,
            noise_predictor=noise_predictor,
            image_shape=(self.height, self.width),
            batch_size=2,
            in_channels=self.channels,
            device='cpu',
        )
        generated_imgs = sampler(save_images=False)

        # Verify results: one loss entry per epoch, one image per batch slot.
        self.assertEqual(len(train_losses), 1)
        expected_shape = (2, self.channels, self.height, self.width)
        self.assertEqual(generated_imgs.shape, expected_shape)

    def test_complete_pipeline_conditional(self):
        """Test complete pipeline with conditional generation"""

        # Minimal map-style dataset yielding (image tensor, prompt string)
        # pairs, so the trainer receives text conditions per batch.
        class StringDataset:
            def __init__(self, x_data, prompts):
                self.x_data = x_data
                self.prompts = prompts

            def __len__(self):
                return len(self.x_data)

            def __getitem__(self, idx):
                return self.x_data[idx], self.prompts[idx]

        x_data = torch.randn(8, self.channels, self.height, self.width)
        prompts = [f"image {i}" for i in range(8)]
        string_dataset = StringDataset(x_data, prompts)
        string_loader = DataLoader(string_dataset, batch_size=2, shuffle=True)

        # Setup components
        variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False)
        forward_ddpm = ForwardDDPM(variance_scheduler)
        reverse_ddpm = ReverseDDPM(variance_scheduler)
        noise_predictor = MockNoisePredictor(self.channels)
        conditional_model = MockConditionalModel()
        # Optimize the noise predictor and the conditional model jointly.
        optimizer = torch.optim.Adam([
            *noise_predictor.parameters(),
            *conditional_model.parameters(),
        ], lr=1e-3)
        objective = nn.MSELoss()

        # Training
        trainer = TrainDDPM(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_ddpm,
            reverse_diffusion=reverse_ddpm,
            data_loader=string_loader,
            optimizer=optimizer,
            objective=objective,
            conditional_model=conditional_model,
            max_epochs=1,
            device='cpu',
            store_path=self.test_dir,
            log_frequency=1,
        )
        train_losses, _ = trainer()

        # Conditional sampling
        sampler = SampleDDPM(
            reverse_diffusion=reverse_ddpm,
            noise_predictor=noise_predictor,
            image_shape=(self.height, self.width),
            conditional_model=conditional_model,
            batch_size=2,
            in_channels=self.channels,
            device='cpu',
        )
        test_prompts = ["test image 1", "test image 2"]
        generated_imgs = sampler(conditions=test_prompts, save_images=False)

        # Verify results
        self.assertEqual(len(train_losses), 1)
        expected_shape = (2, self.channels, self.height, self.width)
        self.assertEqual(generated_imgs.shape, expected_shape)

    def test_trainable_variance_schedule(self):
        """Test pipeline with trainable variance schedule"""
        variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=True)
        forward_ddpm = ForwardDDPM(variance_scheduler)
        reverse_ddpm = ReverseDDPM(variance_scheduler)
        noise_predictor = MockNoisePredictor(self.channels)

        # Include variance scheduler parameters in optimizer
        all_params = list(noise_predictor.parameters()) + list(variance_scheduler.parameters())
        optimizer = torch.optim.Adam(all_params, lr=1e-3)
        objective = nn.MSELoss()

        trainer = TrainDDPM(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_ddpm,
            reverse_diffusion=reverse_ddpm,
            data_loader=self.data_loader,
            optimizer=optimizer,
            objective=objective,
            max_epochs=1,
            device='cpu',
            store_path=self.test_dir,
            log_frequency=1,
        )

        train_losses, _ = trainer()
        self.assertEqual(len(train_losses), 1)

    def test_checkpoint_save_load_cycle(self):
        """Train once, then load the written checkpoint into a fresh trainer."""
        variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False)
        forward_ddpm = ForwardDDPM(variance_scheduler)
        reverse_ddpm = ReverseDDPM(variance_scheduler)
        noise_predictor = MockNoisePredictor(self.channels)
        optimizer = torch.optim.Adam(noise_predictor.parameters(), lr=1e-3)
        objective = nn.MSELoss()

        # Create first trainer and train
        trainer1 = TrainDDPM(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_ddpm,
            reverse_diffusion=reverse_ddpm,
            data_loader=self.data_loader,
            optimizer=optimizer,
            objective=objective,
            max_epochs=1,
            device='cpu',
            store_path=self.test_dir,
            val_frequency=1,
        )
        train_losses1, _ = trainer1()
        self.assertEqual(len(train_losses1), 1)

        # Create new components and load checkpoint
        new_noise_predictor = MockNoisePredictor(self.channels)
        new_optimizer = torch.optim.Adam(new_noise_predictor.parameters(), lr=1e-3)

        trainer2 = TrainDDPM(
            noise_predictor=new_noise_predictor,
            forward_diffusion=forward_ddpm,
            reverse_diffusion=reverse_ddpm,
            data_loader=self.data_loader,
            optimizer=new_optimizer,
            objective=objective,
            max_epochs=1,
            device='cpu',
            store_path=self.test_dir,
        )

        # Find checkpoint file. os.listdir order is arbitrary, so sort to make
        # the choice deterministic when more than one checkpoint exists.
        checkpoint_files = sorted(f for f in os.listdir(self.test_dir) if f.endswith('.pth'))
        self.assertGreater(len(checkpoint_files), 0, "No checkpoint file found")

        checkpoint_path = os.path.join(self.test_dir, checkpoint_files[0])
        epoch, loss = trainer2.load_checkpoint(checkpoint_path)

        self.assertIsInstance(epoch, int)
        self.assertIsInstance(loss, float)
|
|
964
|
+
|
|
965
|
+
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and error conditions"""

    def test_different_beta_schedules_consistency(self):
        """Every listed beta schedule yields finite forward and reverse steps.

        For each schedule, a 50-step scheduler is built and a random batch is
        noised with ForwardDDPM. One ReverseDDPM step is then taken, using the
        ground-truth noise as the "prediction".
        """
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            # subTest reports each failing schedule separately instead of
            # aborting at the first one.
            with self.subTest(beta_method=method):
                scheduler = VarianceSchedulerDDPM(
                    num_steps=50,
                    beta_method=method,
                    trainable_beta=False,
                )
                forward_ddpm = ForwardDDPM(scheduler)
                reverse_ddpm = ReverseDDPM(scheduler)

                # Test forward-reverse consistency on interior timesteps.
                x0 = torch.randn(2, 1, 16, 16)
                noise = torch.randn_like(x0)
                t = torch.randint(1, scheduler.num_steps - 1, (2,))

                # Forward process
                xt = forward_ddpm(x0, noise, t)

                # Reverse process (using ground truth noise)
                xt_minus_1 = reverse_ddpm(xt, noise, t)

                # Results should be finite
                self.assertTrue(torch.all(torch.isfinite(xt)))
                self.assertTrue(torch.all(torch.isfinite(xt_minus_1)))

    def test_extreme_timesteps(self):
        """Forward diffusion at the first and the last timestep.

        At t=0 the output must equal the closed form
        ``sqrt(alpha_bar_0) * x0 + sqrt(1 - alpha_bar_0) * noise``. At
        t=num_steps-1 it must at least be finite.
        """
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        x0 = torch.randn(2, 1, 16, 16)
        noise = torch.randn_like(x0)

        # Test t=0 (least noise)
        t_zero = torch.zeros(2, dtype=torch.long)
        xt_zero = forward_ddpm(x0, noise, t_zero)

        # Compare against the closed-form noising at t=0 rather than against
        # x0 itself: x0 is still mixed with sqrt(1 - alpha_bar_0) * noise.
        alpha_bar_0 = scheduler.alpha_bars[0]
        expected_xt_zero = torch.sqrt(alpha_bar_0) * x0 + torch.sqrt(1 - alpha_bar_0) * noise
        torch.testing.assert_close(xt_zero, expected_xt_zero, atol=1e-6, rtol=1e-6)

        # Test t=max-1 (maximum noise); only finiteness is checked here.
        t_max = torch.full((2,), scheduler.num_steps - 1, dtype=torch.long)
        xt_max = forward_ddpm(x0, noise, t_max)
        self.assertTrue(torch.all(torch.isfinite(xt_max)))

    def test_batch_size_consistency(self):
        """Test consistency across different batch sizes"""
        scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        # Test with different batch sizes
        for batch_size in [1, 4, 8]:
            with self.subTest(batch_size=batch_size):
                x0 = torch.randn(batch_size, 3, 16, 16)
                noise = torch.randn_like(x0)
                t = torch.randint(0, scheduler.num_steps, (batch_size,))

                xt = forward_ddpm(x0, noise, t)

                self.assertEqual(xt.shape[0], batch_size)
                self.assertTrue(torch.all(torch.isfinite(xt)))

    def test_memory_efficiency(self):
        """Test that operations don't cause memory leaks (basic check)

        NOTE(review): every tensor below is created on the CPU, and the only
        assertion reads ``torch.cuda.memory_allocated()``. On CUDA machines
        the check therefore does not observe this loop's allocations, and on
        CPU-only machines no assertion runs at all; the test effectively only
        checks that repeated forward calls do not raise.
        """
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0

        # Perform multiple operations
        for _ in range(10):
            x0 = torch.randn(4, 3, 32, 32)
            noise = torch.randn_like(x0)
            t = torch.randint(0, scheduler.num_steps, (4,))

            xt = forward_ddpm(x0, noise, t)
            del xt, x0, noise, t

        # Memory usage shouldn't grow significantly
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            final_memory = torch.cuda.memory_allocated()
            # Allow for some variance in memory usage
            self.assertLess(final_memory - initial_memory, 100 * 1024 * 1024)  # 100MB threshold
|
|
1060
|
+
|
|
1061
|
+
def run_performance_benchmarks(*, num_steps=1000, batch_size=32, image_size=64,
                               warmup_iters=5, timed_iters=100):
    """Optional performance benchmarks (not part of main test suite).

    Times ForwardDDPM on one fixed random batch and prints the mean latency
    per batch and the resulting throughput. Every argument defaults to the
    value that used to be hard-coded, so a bare call behaves as before.

    Args:
        num_steps: Number of diffusion steps of the variance scheduler.
        batch_size: Number of images in the benchmark batch.
        image_size: Height and width of each 3-channel image.
        warmup_iters: Untimed calls made before measuring.
        timed_iters: Timed calls averaged into the reported latency.

    Raises:
        ValueError: If ``timed_iters`` is less than 1.
    """
    import time

    if timed_iters < 1:
        raise ValueError("timed_iters must be >= 1")

    print("\n" + "=" * 50)
    print("PERFORMANCE BENCHMARKS")
    print("=" * 50)

    # Benchmark forward diffusion
    scheduler = VarianceSchedulerDDPM(num_steps=num_steps, trainable_beta=False)
    forward_ddpm = ForwardDDPM(scheduler)

    x0 = torch.randn(batch_size, 3, image_size, image_size)
    noise = torch.randn_like(x0)
    t = torch.randint(0, scheduler.num_steps, (batch_size,))

    # No autograd graph is needed for a pure timing run.
    with torch.no_grad():
        # Warmup
        for _ in range(warmup_iters):
            forward_ddpm(x0, noise, t)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(timed_iters):
            forward_ddpm(x0, noise, t)
        elapsed = time.perf_counter() - start_time

    avg_time = elapsed / timed_iters
    print(f"Forward diffusion avg time: {avg_time * 1000:.2f}ms per batch")
    print(f"Throughput: {batch_size / avg_time:.0f} images/second")
|
|
1090
|
+
|
|
1091
|
+
class DDPMTestSuite:
|
|
1092
|
+
"""Main test suite runner with custom reporting"""
|
|
1093
|
+
|
|
1094
|
+
def __init__(self):
|
|
1095
|
+
self.test_classes = [
|
|
1096
|
+
TestVarianceSchedulerDDPM,
|
|
1097
|
+
TestForwardDDPM,
|
|
1098
|
+
TestReverseDDPM,
|
|
1099
|
+
TestTrainDDPM,
|
|
1100
|
+
TestSampleDDPM,
|
|
1101
|
+
TestIntegration,
|
|
1102
|
+
TestEdgeCases
|
|
1103
|
+
]
|
|
1104
|
+
|
|
1105
|
+
def run_all_tests(self, verbose=True):
|
|
1106
|
+
"""Run all tests with custom reporting"""
|
|
1107
|
+
print("=" * 60)
|
|
1108
|
+
print("DDPM IMPLEMENTATION TEST SUITE")
|
|
1109
|
+
print("=" * 60)
|
|
1110
|
+
|
|
1111
|
+
total_tests = 0
|
|
1112
|
+
total_failures = 0
|
|
1113
|
+
total_errors = 0
|
|
1114
|
+
|
|
1115
|
+
for test_class in self.test_classes:
|
|
1116
|
+
print(f"\nRunning {test_class.__name__}...")
|
|
1117
|
+
|
|
1118
|
+
# Create test suite for this class
|
|
1119
|
+
suite = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
|
1120
|
+
|
|
1121
|
+
# Run tests with custom result handling
|
|
1122
|
+
result = unittest.TextTestRunner(
|
|
1123
|
+
verbosity=2 if verbose else 1,
|
|
1124
|
+
stream=open(os.devnull, 'w') if not verbose else None
|
|
1125
|
+
).run(suite)
|
|
1126
|
+
|
|
1127
|
+
# Count results
|
|
1128
|
+
tests_run = result.testsRun
|
|
1129
|
+
failures = len(result.failures)
|
|
1130
|
+
errors = len(result.errors)
|
|
1131
|
+
|
|
1132
|
+
total_tests += tests_run
|
|
1133
|
+
total_failures += failures
|
|
1134
|
+
total_errors += errors
|
|
1135
|
+
|
|
1136
|
+
# Print summary for this class
|
|
1137
|
+
status = "PASS" if (failures == 0 and errors == 0) else "FAIL"
|
|
1138
|
+
print(f" {test_class.__name__}: {status}")
|
|
1139
|
+
print(f" Tests: {tests_run}, Failures: {failures}, Errors: {errors}")
|
|
1140
|
+
|
|
1141
|
+
# Print failure details
|
|
1142
|
+
if failures > 0:
|
|
1143
|
+
print(" FAILURES:")
|
|
1144
|
+
for test, traceback in result.failures:
|
|
1145
|
+
print(
|
|
1146
|
+
f" - {test}: {traceback.split('AssertionError:')[-1].strip() if 'AssertionError:' in traceback else 'Unknown failure'}")
|
|
1147
|
+
|
|
1148
|
+
if errors > 0:
|
|
1149
|
+
print(" ERRORS:")
|
|
1150
|
+
for test, traceback in result.errors:
|
|
1151
|
+
error_msg = traceback.split('\n')[-2] if len(traceback.split('\n')) > 1 else "Unknown error"
|
|
1152
|
+
print(f" - {test}: {error_msg}")
|
|
1153
|
+
|
|
1154
|
+
# Final summary
|
|
1155
|
+
print("\n" + "=" * 60)
|
|
1156
|
+
print("FINAL RESULTS")
|
|
1157
|
+
print("=" * 60)
|
|
1158
|
+
print(f"Total Tests: {total_tests}")
|
|
1159
|
+
print(f"Passed: {total_tests - total_failures - total_errors}")
|
|
1160
|
+
print(f"Failed: {total_failures}")
|
|
1161
|
+
print(f"Errors: {total_errors}")
|
|
1162
|
+
|
|
1163
|
+
success_rate = ((total_tests - total_failures - total_errors) / total_tests) * 100 if total_tests > 0 else 0
|
|
1164
|
+
print(f"Success Rate: {success_rate:.1f}%")
|
|
1165
|
+
|
|
1166
|
+
if total_failures == 0 and total_errors == 0:
|
|
1167
|
+
print("\nš ALL TESTS PASSED! š")
|
|
1168
|
+
else:
|
|
1169
|
+
print(f"\nā {total_failures + total_errors} tests failed/errored")
|
|
1170
|
+
|
|
1171
|
+
return total_failures == 0 and total_errors == 0
|
|
1172
|
+
|
|
1173
|
+
if __name__ == "__main__":
    import sys

    # Suppress warnings for cleaner output
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Run main test suite
    suite = DDPMTestSuite()
    all_passed = suite.run_all_tests(verbose=True)

    # Optionally run performance benchmarks. input() raises EOFError when
    # stdin is closed or exhausted (e.g. non-interactive CI runs); treat that
    # as "no" so the exit code below is still reported.
    try:
        answer = input("\nRun performance benchmarks? (y/n): ")
    except EOFError:
        answer = ""
    if answer.lower().startswith('y'):
        run_performance_benchmarks()

    # Exit with appropriate code. sys.exit is always available, unlike the
    # site-provided exit() builtin (missing under `python -S`).
    sys.exit(0 if all_passed else 1)
|