TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,1188 @@
1
+ """
2
+ Comprehensive test suite for the DDPM (Denoising Diffusion Probabilistic Models) implementation.
3
+
4
+ This test suite covers all major components of the DDPM implementation including:
5
+ - VarianceSchedulerDDPM (noise scheduling)
6
+ - ForwardDDPM (forward diffusion process)
7
+ - ReverseDDPM (reverse diffusion process)
8
+ - TrainDDPM (training loop)
9
+ - SampleDDPM (image generation)
10
+ - Integration tests with different configurations
11
+
12
+ Usage:
13
+ python test_ddpm.py
14
+ """
15
+
16
+ import unittest
17
+ import torch
18
+ import torch.nn as nn
19
+ import tempfile
20
+ import shutil
21
+ import os
22
+ from unittest.mock import Mock, patch
23
+ import numpy as np
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+ import warnings
26
+ from torchdiff.ddpm import VarianceSchedulerDDPM, ForwardDDPM, ReverseDDPM, TrainDDPM, SampleDDPM
27
+
28
+
29
+
30
# Simple mock classes for testing
class MockNoisePredictor(nn.Module):
    """Tiny two-layer convolutional noise predictor used as a test stand-in.

    The timestep is embedded with a single ``nn.Linear(1, 16)`` layer and added,
    broadcast over the spatial dimensions, to the output of the first conv.
    Conditioning arguments are accepted for interface compatibility but ignored.

    Args:
        in_channels: Number of image channels; the output has the same count.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, in_channels, 3, padding=1)
        self.time_mlp = nn.Linear(1, 16)

    def forward(self, x, t, y_encoded=None, context=None):
        """Predict noise for ``x`` at timestep(s) ``t``.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.
            t: Timesteps as a 0-d tensor, a ``(B,)`` tensor or a ``(B, 1)``
                tensor. A single timestep is broadcast over the whole batch.
            y_encoded: Ignored (conditioning embedding).
            context: Ignored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Flatten to (N, 1) so scalar, (B,) and (B, 1) timesteps all work.
        # The former ``unsqueeze(-1)`` + ``view`` crashed for 0-d and (B, 1) inputs.
        t_embed = self.time_mlp(t.float().reshape(-1, 1))
        # (N, 16) -> (N, 16, 1, 1): broadcasting spreads the embedding over
        # H and W (and over the batch when N == 1), matching the old expand().
        t_embed = t_embed[:, :, None, None]

        hidden = torch.relu(self.conv1(x) + t_embed)
        return self.conv2(hidden)
51
+
52
+
53
class MockConditionalModel(nn.Module):
    """Stand-in text encoder that returns random embeddings.

    Only the batch dimension of ``input_ids`` is used; the token values and the
    attention mask are ignored. The module has no trainable parameters.

    Args:
        embed_dim: Size of each returned embedding vector.
    """

    def __init__(self, embed_dim=32):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, input_ids, attention_mask):
        """Return a ``(batch_size, embed_dim)`` tensor of standard-normal noise.

        The output is allocated on the same device as ``input_ids``, so the
        mock keeps working when the caller has moved its inputs off the CPU.
        """
        batch_size = input_ids.shape[0]
        return torch.randn(batch_size, self.embed_dim, device=input_ids.device)
63
+
64
+
65
class MockMetrics:
    """Mock metrics object that returns fixed scores.

    The ``fid``, ``metrics`` and ``lpips`` flags are all set to ``True``.
    """

    def __init__(self):
        self.fid = True
        self.metrics = True
        self.lpips = True

    def forward(self, real_imgs, fake_imgs):
        """Return constant ``(fid, mse, psnr, ssim, lpips)`` scores; inputs are ignored."""
        return 50.0, 0.1, 25.0, 0.8, 0.05  # fid, mse, psnr, ssim, lpips

    def __call__(self, real_imgs, fake_imgs):
        """Delegate to ``forward`` so the mock can be called like an ``nn.Module``."""
        # NOTE(review): the trainer may invoke either ``metrics_(...)`` or
        # ``metrics_.forward(...)``; supporting both keeps the mock usable
        # whichever convention torchdiff uses.
        return self.forward(real_imgs, fake_imgs)
76
+
77
+
78
class TestVarianceSchedulerDDPM(unittest.TestCase):
    """Test cases for VarianceSchedulerDDPM"""

    # Absolute slack for float32 round-off when comparing betas with the
    # configured bounds (e.g. a quadratic schedule squares a square root).
    BOUND_ATOL = 1e-7

    def setUp(self):
        self.num_steps = 100
        self.beta_start = 1e-4
        self.beta_end = 0.02

    def test_initialization_fixed_schedule(self):
        """Test initialization with fixed (non-trainable) schedule"""
        scheduler = VarianceSchedulerDDPM(
            num_steps=self.num_steps,
            beta_start=self.beta_start,
            beta_end=self.beta_end,
            trainable_beta=False,
        )

        self.assertEqual(scheduler.num_steps, self.num_steps)
        self.assertEqual(scheduler.beta_start, self.beta_start)
        self.assertEqual(scheduler.beta_end, self.beta_end)
        self.assertFalse(scheduler.trainable_beta)

        # The schedule tensors must be exposed as attributes
        # (hasattr does not distinguish buffers from plain attributes).
        for attr in ('betas', 'alphas', 'alpha_bars'):
            with self.subTest(attr=attr):
                self.assertTrue(hasattr(scheduler, attr))
        self.assertEqual(scheduler.betas.shape[0], self.num_steps)

    def test_initialization_trainable_schedule(self):
        """Test initialization with trainable schedule"""
        scheduler = VarianceSchedulerDDPM(
            num_steps=self.num_steps,
            trainable_beta=True,
        )

        self.assertTrue(scheduler.trainable_beta)
        self.assertIsInstance(scheduler.betas, nn.Parameter)

    def test_invalid_parameters(self):
        """Test initialization with invalid parameters"""
        invalid_kwargs = [
            {"beta_start": 0.02, "beta_end": 1e-4},  # start > end
            {"num_steps": -1},  # negative steps
            {"beta_start": 0.0},  # zero start
        ]
        for kwargs in invalid_kwargs:
            with self.subTest(**kwargs):
                with self.assertRaises(ValueError):
                    VarianceSchedulerDDPM(**kwargs)

    def test_beta_schedule_methods(self):
        """Every beta schedule yields num_steps betas within [beta_start, beta_end]."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            with self.subTest(method=method):
                # Pass the bounds explicitly: the assertions below compare
                # against them, so relying on the constructor defaults would
                # silently test the wrong range if those defaults ever differ.
                scheduler = VarianceSchedulerDDPM(
                    num_steps=self.num_steps,
                    beta_start=self.beta_start,
                    beta_end=self.beta_end,
                    beta_method=method,
                )
                betas = scheduler.betas
                self.assertEqual(betas.shape[0], self.num_steps)
                self.assertTrue(torch.all(betas >= self.beta_start - self.BOUND_ATOL))
                self.assertTrue(torch.all(betas <= self.beta_end + self.BOUND_ATOL))

    def test_invalid_beta_method(self):
        """Test invalid beta schedule method"""
        with self.assertRaises(ValueError):
            VarianceSchedulerDDPM(beta_method="invalid_method")

    def test_compute_schedule_fixed(self):
        """Test compute_schedule method with fixed schedule"""
        scheduler = VarianceSchedulerDDPM(num_steps=self.num_steps, trainable_beta=False)

        # Without time_steps: every output (betas, alphas, alpha_bars,
        # sqrt_alpha_bars, sqrt_one_minus_alpha_bars) covers all steps.
        full_schedule = scheduler.compute_schedule()
        self.assertEqual(len(full_schedule), 5)
        for index, tensor in enumerate(full_schedule):
            with self.subTest(output=index, time_steps=None):
                self.assertEqual(tensor.shape[0], self.num_steps)

        # With specific time_steps: one entry per requested step.
        time_steps = torch.tensor([0, 50, 99])
        selected = scheduler.compute_schedule(time_steps)
        self.assertEqual(len(selected), 5)
        for index, tensor in enumerate(selected):
            with self.subTest(output=index, time_steps="selected"):
                self.assertEqual(tensor.shape[0], len(time_steps))

    def test_compute_schedule_trainable(self):
        """Test compute_schedule method with trainable schedule"""
        scheduler = VarianceSchedulerDDPM(num_steps=self.num_steps, trainable_beta=True)

        time_steps = torch.tensor([0, 50, 99])
        betas_t, _, _, _, _ = scheduler.compute_schedule(time_steps)

        self.assertEqual(betas_t.shape[0], 3)
        # Trainable betas must still be valid variances in (0, 1).
        self.assertTrue(torch.all(betas_t > 0))
        self.assertTrue(torch.all(betas_t < 1))
175
+
176
+
177
class TestForwardDDPM(unittest.TestCase):
    """Test cases for ForwardDDPM"""

    def setUp(self):
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        self.forward_ddpm = ForwardDDPM(self.variance_scheduler)
        self.batch_size = 4
        self.channels = 1
        self.height = 28
        self.width = 28

    def _random_inputs(self):
        """Return a random clean batch ``x0`` and matching Gaussian ``noise``."""
        x0 = torch.randn(self.batch_size, self.channels, self.height, self.width)
        return x0, torch.randn_like(x0)

    def test_forward_process(self):
        """Test the forward diffusion process"""
        x0, noise = self._random_inputs()
        time_steps = torch.randint(0, self.variance_scheduler.num_steps, (self.batch_size,))

        xt = self.forward_ddpm(x0, noise, time_steps)

        # The output keeps the input shape and contains no NaN/Inf.
        self.assertEqual(xt.shape, x0.shape)
        self.assertTrue(torch.all(torch.isfinite(xt)))

    def test_forward_process_trainable(self):
        """Test forward process with trainable variance scheduler"""
        variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=True)
        forward_ddpm = ForwardDDPM(variance_scheduler)

        x0, noise = self._random_inputs()
        time_steps = torch.randint(0, variance_scheduler.num_steps, (self.batch_size,))

        xt = forward_ddpm(x0, noise, time_steps)
        self.assertEqual(xt.shape, x0.shape)

    def test_invalid_time_steps(self):
        """Test with invalid time steps"""
        x0, noise = self._random_inputs()

        # Contains values at/above num_steps and below zero.
        invalid_time_steps = torch.tensor([100, 200, -1, 50])

        with self.assertRaises(ValueError):
            self.forward_ddpm(x0, noise, invalid_time_steps)

    def test_edge_cases(self):
        """t=0 and t=max-1 keep the shape, and t=0 stays closer to x0."""
        x0, noise = self._random_inputs()
        last_step = self.variance_scheduler.num_steps - 1

        xt_first = self.forward_ddpm(x0, noise, torch.zeros(self.batch_size, dtype=torch.long))
        self.assertEqual(xt_first.shape, x0.shape)

        xt_last = self.forward_ddpm(
            x0, noise, torch.full((self.batch_size,), last_step, dtype=torch.long)
        )
        self.assertEqual(xt_last.shape, x0.shape)

        # The first step adds the least noise, so it must stay closer to x0
        # than the final step (previously only the shapes were checked).
        error_first = (xt_first - x0).abs().mean().item()
        error_last = (xt_last - x0).abs().mean().item()
        self.assertLess(error_first, error_last)
239
+
240
+
241
class TestReverseDDPM(unittest.TestCase):
    """Test cases for ReverseDDPM"""

    def setUp(self):
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
        self.batch_size = 4
        self.channels = 1
        self.height = 28
        self.width = 28

    def _random_inputs(self):
        """Return a random noisy batch ``xt`` and a random ``predicted_noise``."""
        xt = torch.randn(self.batch_size, self.channels, self.height, self.width)
        return xt, torch.randn_like(xt)

    def test_reverse_process(self):
        """Test the reverse diffusion process"""
        xt, predicted_noise = self._random_inputs()
        time_steps = torch.randint(1, self.variance_scheduler.num_steps, (self.batch_size,))

        xt_minus_1 = self.reverse_ddpm(xt, predicted_noise, time_steps)

        # The output keeps the input shape and contains no NaN/Inf.
        self.assertEqual(xt_minus_1.shape, xt.shape)
        self.assertTrue(torch.all(torch.isfinite(xt_minus_1)))

    def test_reverse_process_t_zero(self):
        """Test reverse process at t=0 (deterministic)"""
        xt, predicted_noise = self._random_inputs()
        time_steps = torch.zeros(self.batch_size, dtype=torch.long)

        first = self.reverse_ddpm(xt, predicted_noise, time_steps)
        second = self.reverse_ddpm(xt, predicted_noise, time_steps)

        self.assertEqual(first.shape, xt.shape)
        self.assertTrue(torch.all(torch.isfinite(first)))
        # At t=0 no random noise should be added, so repeated calls on the
        # same inputs must agree (previously this claim was not checked).
        self.assertTrue(torch.allclose(first, second))

    def test_reverse_process_trainable(self):
        """Test reverse process with trainable variance scheduler"""
        variance_scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=True)
        reverse_ddpm = ReverseDDPM(variance_scheduler)

        xt, predicted_noise = self._random_inputs()
        time_steps = torch.randint(1, variance_scheduler.num_steps, (self.batch_size,))

        xt_minus_1 = reverse_ddpm(xt, predicted_noise, time_steps)
        self.assertEqual(xt_minus_1.shape, xt.shape)

    def test_invalid_time_steps(self):
        """Test with invalid time steps"""
        xt, predicted_noise = self._random_inputs()

        # Contains values at/above num_steps and below zero.
        invalid_time_steps = torch.tensor([100, 200, -1, 50])

        with self.assertRaises(ValueError):
            self.reverse_ddpm(xt, predicted_noise, invalid_time_steps)
299
+
300
+
301
class TestTrainDDPM(unittest.TestCase):
    """Test cases for TrainDDPM"""

    def setUp(self):
        # Create temporary directory for test outputs
        self.test_dir = tempfile.mkdtemp()

        # Mock data dimensions
        self.batch_size = 8
        self.channels = 1
        self.height = 28
        self.width = 28

        # 32 random images paired with integer labels; the conditional test
        # builds its own string-prompt dataset.
        x_data = torch.randn(32, self.channels, self.height, self.width)
        dataset = TensorDataset(x_data, torch.arange(32))

        self.train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

        # Create components
        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
        self.forward_ddpm = ForwardDDPM(self.variance_scheduler)
        self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
        self.noise_predictor = MockNoisePredictor(self.channels)
        self.optimizer = torch.optim.Adam(self.noise_predictor.parameters(), lr=1e-3)
        self.objective = nn.MSELoss()

    def tearDown(self):
        # Clean up temporary directory
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _make_trainer(self, **overrides):
        """Build a TrainDDPM from the fixture components.

        Keyword arguments override or extend the shared defaults, so each test
        spells out only what makes it different.
        """
        kwargs = dict(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_ddpm,
            reverse_diffusion=self.reverse_ddpm,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            max_epochs=1,
            device='cpu',
            store_path=self.test_dir,
        )
        kwargs.update(overrides)
        return TrainDDPM(**kwargs)

    def test_initialization(self):
        """Test TrainDDPM initialization"""
        trainer = self._make_trainer(max_epochs=2)

        self.assertEqual(trainer.max_epochs, 2)
        self.assertEqual(trainer.device, torch.device('cpu'))
        self.assertIsNotNone(trainer.scheduler)
        self.assertIsNotNone(trainer.warmup_lr_scheduler)

    def test_training_loop_basic(self):
        """Test basic training loop without validation"""
        trainer = self._make_trainer(max_epochs=2, log_frequency=1)

        train_losses, best_val_loss = trainer()

        self.assertEqual(len(train_losses), 2)  # 2 epochs
        self.assertTrue(all(isinstance(loss, float) for loss in train_losses))
        self.assertIsInstance(best_val_loss, float)

    def test_training_with_validation(self):
        """Test training loop with validation"""
        metrics = MockMetrics()

        trainer = self._make_trainer(
            val_loader=self.val_loader,
            metrics_=metrics,
            max_epochs=2,
            val_frequency=1,
            log_frequency=1,
        )

        train_losses, best_val_loss = trainer()

        self.assertEqual(len(train_losses), 2)
        self.assertIsInstance(best_val_loss, float)

    def test_training_with_conditional_model(self):
        """Test training with conditional model"""
        conditional_model = MockConditionalModel()

        # Data loader yielding (image, string prompt) pairs
        x_data = torch.randn(16, self.channels, self.height, self.width)
        y_data = ['test prompt'] * 16

        class StringDataset:
            """Minimal map-style dataset pairing tensors with string prompts."""

            def __init__(self, x_data, y_data):
                self.x_data = x_data
                self.y_data = y_data

            def __len__(self):
                return len(self.x_data)

            def __getitem__(self, idx):
                return self.x_data[idx], self.y_data[idx]

        string_loader = DataLoader(StringDataset(x_data, y_data), batch_size=4, shuffle=True)

        trainer = self._make_trainer(
            data_loader=string_loader,
            conditional_model=conditional_model,
            log_frequency=1,
        )

        train_losses, _ = trainer()
        self.assertEqual(len(train_losses), 1)

    def test_gradient_accumulation(self):
        """Test training with gradient accumulation"""
        trainer = self._make_trainer(grad_accumulation_steps=2, log_frequency=1)

        train_losses, _ = trainer()
        self.assertEqual(len(train_losses), 1)

    def test_warmup_scheduler(self):
        """Test warmup scheduler functionality"""
        optimizer = torch.optim.Adam(self.noise_predictor.parameters(), lr=1e-3)
        warmup_epochs = 5

        scheduler = TrainDDPM.warmup_scheduler(optimizer, warmup_epochs)

        # Inspect the LambdaLR multiplier function directly
        lr_lambda = scheduler.lr_lambdas[0]

        # Warmup phase: multiplier grows linearly from 0 towards 1
        for epoch in range(warmup_epochs):
            with self.subTest(epoch=epoch):
                self.assertAlmostEqual(lr_lambda(epoch), epoch / warmup_epochs, places=5)

        # Post-warmup phase: multiplier stays at 1.0, including much later
        for epoch in (warmup_epochs, warmup_epochs + 1, warmup_epochs + 10):
            with self.subTest(epoch=epoch):
                self.assertEqual(lr_lambda(epoch), 1.0)

        # At the start the multiplier is exactly 0
        self.assertEqual(lr_lambda(0), 0.0)

        # A different warmup length
        lr_lambda2 = TrainDDPM.warmup_scheduler(optimizer, 10).lr_lambdas[0]

        self.assertEqual(lr_lambda2(0), 0.0)
        self.assertAlmostEqual(lr_lambda2(5), 0.5, places=5)
        self.assertEqual(lr_lambda2(10), 1.0)

    @patch('torch.save')
    def test_checkpoint_saving(self, mock_save):
        """Test checkpoint saving functionality"""
        trainer = self._make_trainer(val_frequency=1)

        # Exercise the private checkpoint saving method
        trainer._save_checkpoint(1, 0.5)

        # torch.save must have been invoked
        mock_save.assert_called()

    def test_checkpoint_loading(self):
        """Test checkpoint loading functionality"""
        trainer = self._make_trainer()

        # Write a checkpoint with the keys the loader is expected to read
        checkpoint = {
            'epoch': 5,
            'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
            'model_state_dict_conditional': None,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': 0.5,
            'variance_scheduler_model': self.variance_scheduler.state_dict(),
            'max_epochs': 10,
        }
        checkpoint_path = os.path.join(self.test_dir, 'test_checkpoint.pth')
        torch.save(checkpoint, checkpoint_path)

        epoch, loss = trainer.load_checkpoint(checkpoint_path)

        self.assertEqual(epoch, 5)
        self.assertEqual(loss, 0.5)

    def test_validation_method(self):
        """Test validation method"""
        metrics = MockMetrics()

        trainer = self._make_trainer(val_loader=self.val_loader, metrics_=metrics)

        val_metrics = trainer.validate()

        self.assertEqual(len(val_metrics), 6)  # val_loss, fid, mse, psnr, ssim, lpips
        val_loss, _fid, _mse, _psnr, _ssim, _lpips_score = val_metrics
        self.assertIsInstance(val_loss, float)
571
+
572
+
573
class TestSampleDDPM(unittest.TestCase):
    """Test cases for SampleDDPM"""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

        self.variance_scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
        self.reverse_ddpm = ReverseDDPM(self.variance_scheduler)
        self.noise_predictor = MockNoisePredictor(in_channels=3)
        self.conditional_model = MockConditionalModel()

        self.image_shape = (32, 32)
        self.batch_size = 4

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def _make_sampler(self, **overrides):
        """Build a CPU SampleDDPM from the fixture components.

        Keyword arguments override or extend the shared defaults
        (``reverse_diffusion``, ``noise_predictor``, ``image_shape``,
        ``batch_size`` and ``device='cpu'``).
        """
        kwargs = dict(
            reverse_diffusion=self.reverse_ddpm,
            noise_predictor=self.noise_predictor,
            image_shape=self.image_shape,
            batch_size=self.batch_size,
            device='cpu',
        )
        kwargs.update(overrides)
        return SampleDDPM(**kwargs)

    def test_initialization(self):
        """Test SampleDDPM initialization"""
        sampler = self._make_sampler()

        self.assertEqual(sampler.image_shape, self.image_shape)
        self.assertEqual(sampler.batch_size, self.batch_size)
        self.assertEqual(sampler.device, torch.device('cpu'))

    def test_invalid_initialization(self):
        """Test initialization with invalid parameters"""
        # Built without the device override so validation runs on the
        # constructor defaults, exactly as before.
        invalid_kwargs = [
            {"image_shape": (32,), "batch_size": self.batch_size},  # invalid shape
            {"image_shape": self.image_shape, "batch_size": 0},  # invalid batch size
        ]
        for kwargs in invalid_kwargs:
            with self.subTest(**kwargs):
                with self.assertRaises(ValueError):
                    SampleDDPM(
                        reverse_diffusion=self.reverse_ddpm,
                        noise_predictor=self.noise_predictor,
                        **kwargs,
                    )

    def test_tokenization(self):
        """Test text tokenization"""
        sampler = self._make_sampler(conditional_model=self.conditional_model)

        # A single prompt string is tokenized as a batch of one
        input_ids, attention_mask = sampler.tokenize("a beautiful landscape")
        self.assertEqual(input_ids.shape[0], 1)
        self.assertEqual(attention_mask.shape[0], 1)

        # A list of prompts yields one row per prompt
        input_ids, attention_mask = sampler.tokenize(["prompt 1", "prompt 2", "prompt 3"])
        self.assertEqual(input_ids.shape[0], 3)
        self.assertEqual(attention_mask.shape[0], 3)

    def test_unconditional_generation(self):
        """Test unconditional image generation"""
        sampler = self._make_sampler()

        generated_imgs = sampler(
            conditions=None,
            normalize_output=True,
            save_images=False,
        )

        expected_shape = (self.batch_size, 3, self.image_shape[0], self.image_shape[1])
        self.assertEqual(generated_imgs.shape, expected_shape)

        # Normalized images must lie in [0, 1]
        self.assertTrue(torch.all(generated_imgs >= 0))
        self.assertTrue(torch.all(generated_imgs <= 1))

    def test_conditional_generation(self):
        """Test conditional image generation"""
        sampler = self._make_sampler(conditional_model=self.conditional_model)

        prompts = ["a beautiful sunset", "a mountain landscape", "a city skyline", "a forest path"]

        generated_imgs = sampler(
            conditions=prompts,
            normalize_output=True,
            save_images=False,
        )

        expected_shape = (self.batch_size, 3, self.image_shape[0], self.image_shape[1])
        self.assertEqual(generated_imgs.shape, expected_shape)

    def test_image_saving(self):
        """Test image saving functionality"""
        sampler = self._make_sampler(batch_size=2)  # smaller batch for a faster test

        sampler(
            conditions=None,
            save_images=True,
            save_path=self.test_dir,
        )

        # One PNG per generated image
        saved_files = os.listdir(self.test_dir)
        self.assertEqual(len(saved_files), 2)
        self.assertTrue(all(f.endswith('.png') for f in saved_files))

    def test_device_transfer(self):
        """Test device transfer functionality"""
        sampler = self._make_sampler()

        # Moving to the CPU must work and update the recorded device
        sampler_moved = sampler.to(torch.device('cpu'))
        self.assertEqual(sampler_moved.device, torch.device('cpu'))

    def test_error_conditions(self):
        """Test error conditions in sampling"""
        # Conditional generation without a conditional model
        sampler = self._make_sampler()
        with self.assertRaises(ValueError):
            sampler(conditions=["test prompt"])

        # Unconditional generation with a conditional model
        sampler_conditional = self._make_sampler(conditional_model=self.conditional_model)
        with self.assertRaises(ValueError):
            sampler_conditional(conditions=None)
753
+
754
+ class TestIntegration(unittest.TestCase):
755
+ """Integration tests for the complete DDPM pipeline"""
756
+
757
+ def setUp(self):
758
+ self.test_dir = tempfile.mkdtemp()
759
+
760
+ # Create a simple dataset
761
+ self.batch_size = 4
762
+ self.channels = 1
763
+ self.height = 16 # Smaller for faster tests
764
+ self.width = 16
765
+
766
+ x_data = torch.randn(16, self.channels, self.height, self.width)
767
+ y_data = torch.arange(16)
768
+ dataset = TensorDataset(x_data, y_data)
769
+
770
+ self.data_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
771
+
772
+ def tearDown(self):
773
+ shutil.rmtree(self.test_dir, ignore_errors=True)
774
+
775
+ def test_complete_pipeline_unconditional(self):
776
+ """Test complete pipeline: training -> sampling (unconditional)"""
777
+ # Setup components
778
+ variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False) # Small for speed
779
+ forward_ddpm = ForwardDDPM(variance_scheduler)
780
+ reverse_ddpm = ReverseDDPM(variance_scheduler)
781
+ noise_predictor = MockNoisePredictor(self.channels)
782
+ optimizer = torch.optim.Adam(noise_predictor.parameters(), lr=1e-3)
783
+ objective = nn.MSELoss()
784
+
785
+ # Training
786
+ trainer = TrainDDPM(
787
+ noise_predictor=noise_predictor,
788
+ forward_diffusion=forward_ddpm,
789
+ reverse_diffusion=reverse_ddpm,
790
+ data_loader=self.data_loader,
791
+ optimizer=optimizer,
792
+ objective=objective,
793
+ max_epochs=1, # Just 1 epoch for speed
794
+ device='cpu',
795
+ store_path=self.test_dir,
796
+ log_frequency=1
797
+ )
798
+
799
+ train_losses, best_val_loss = trainer()
800
+
801
+ # Sampling
802
+ sampler = SampleDDPM(
803
+ reverse_diffusion=reverse_ddpm,
804
+ noise_predictor=noise_predictor,
805
+ image_shape=(self.height, self.width),
806
+ batch_size=2,
807
+ in_channels=self.channels,
808
+ device='cpu'
809
+ )
810
+
811
+ generated_imgs = sampler(save_images=False)
812
+
813
+ # Verify results
814
+ self.assertEqual(len(train_losses), 1)
815
+ expected_shape = (2, self.channels, self.height, self.width)
816
+ self.assertEqual(generated_imgs.shape, expected_shape)
817
+
818
def test_complete_pipeline_conditional(self):
    """End-to-end check of the conditional path.

    Trains a DDPM with a conditioning model for one epoch on
    ``(image, prompt)`` pairs, then samples a batch guided by text prompts
    and checks the loss history length and the output shape.
    """

    class StringDataset:
        """Minimal map-style dataset yielding ``(image, prompt)`` tuples."""

        def __init__(self, x_data, prompts):
            self.x_data = x_data
            self.prompts = prompts

        def __len__(self):
            return len(self.x_data)

        def __getitem__(self, idx):
            return self.x_data[idx], self.prompts[idx]

    num_samples = 8
    images = torch.randn(num_samples, self.channels, self.height, self.width)
    captions = [f"image {i}" for i in range(num_samples)]
    loader = DataLoader(StringDataset(images, captions), batch_size=2, shuffle=True)

    # Forward and reverse processes share one fixed (non-trainable) 20-step schedule.
    scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False)
    forward_process = ForwardDDPM(scheduler)
    reverse_process = ReverseDDPM(scheduler)
    predictor = MockNoisePredictor(self.channels)
    text_model = MockConditionalModel()

    # The noise predictor and the conditioning model are optimized jointly.
    trainable = list(predictor.parameters()) + list(text_model.parameters())
    optimizer = torch.optim.Adam(trainable, lr=1e-3)

    trainer = TrainDDPM(
        noise_predictor=predictor,
        forward_diffusion=forward_process,
        reverse_diffusion=reverse_process,
        data_loader=loader,
        optimizer=optimizer,
        objective=nn.MSELoss(),
        conditional_model=text_model,
        max_epochs=1,
        device='cpu',
        store_path=self.test_dir,
        log_frequency=1,
    )
    train_losses, _ = trainer()

    sampler = SampleDDPM(
        reverse_diffusion=reverse_process,
        noise_predictor=predictor,
        image_shape=(self.height, self.width),
        conditional_model=text_model,
        batch_size=2,
        in_channels=self.channels,
        device='cpu',
    )
    generated_imgs = sampler(conditions=["test image 1", "test image 2"], save_images=False)

    self.assertEqual(len(train_losses), 1)
    self.assertEqual(generated_imgs.shape, (2, self.channels, self.height, self.width))
def test_trainable_variance_schedule(self):
    """Run one training epoch with a learnable beta schedule.

    The schedule's parameters are optimized alongside the noise predictor.
    """
    scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=True)
    forward_process = ForwardDDPM(scheduler)
    reverse_process = ReverseDDPM(scheduler)
    predictor = MockNoisePredictor(self.channels)

    # Hand the schedule's parameters to the optimizer explicitly,
    # together with the predictor's.
    params = [*predictor.parameters(), *scheduler.parameters()]

    trainer = TrainDDPM(
        noise_predictor=predictor,
        forward_diffusion=forward_process,
        reverse_diffusion=reverse_process,
        data_loader=self.data_loader,
        optimizer=torch.optim.Adam(params, lr=1e-3),
        objective=nn.MSELoss(),
        max_epochs=1,
        device='cpu',
        store_path=self.test_dir,
        log_frequency=1,
    )

    losses, _ = trainer()
    self.assertEqual(len(losses), 1)
def test_checkpoint_save_load_cycle(self):
    """Train one epoch so a checkpoint is written, then restore it.

    The checkpoint is loaded into a fresh trainer, and the test checks that
    ``load_checkpoint`` returns an int epoch and a float loss.
    """
    variance_scheduler = VarianceSchedulerDDPM(num_steps=20, trainable_beta=False)
    forward_ddpm = ForwardDDPM(variance_scheduler)
    reverse_ddpm = ReverseDDPM(variance_scheduler)
    noise_predictor = MockNoisePredictor(self.channels)
    optimizer = torch.optim.Adam(noise_predictor.parameters(), lr=1e-3)
    objective = nn.MSELoss()

    # First trainer: one epoch; expected to leave a .pth checkpoint in test_dir.
    trainer1 = TrainDDPM(
        noise_predictor=noise_predictor,
        forward_diffusion=forward_ddpm,
        reverse_diffusion=reverse_ddpm,
        data_loader=self.data_loader,
        optimizer=optimizer,
        objective=objective,
        max_epochs=1,
        device='cpu',
        store_path=self.test_dir,
        val_frequency=1
    )
    trainer1()  # only the checkpoint side effect matters here

    # Second trainer wraps freshly initialised weights that receive the checkpoint.
    new_noise_predictor = MockNoisePredictor(self.channels)
    new_optimizer = torch.optim.Adam(new_noise_predictor.parameters(), lr=1e-3)

    trainer2 = TrainDDPM(
        noise_predictor=new_noise_predictor,
        forward_diffusion=forward_ddpm,
        reverse_diffusion=reverse_ddpm,
        data_loader=self.data_loader,
        optimizer=new_optimizer,
        objective=objective,
        max_epochs=1,
        device='cpu',
        store_path=self.test_dir
    )

    # os.listdir order is arbitrary; sort so every run loads the same file.
    checkpoint_files = sorted(
        f for f in os.listdir(self.test_dir) if f.endswith('.pth')
    )
    self.assertGreater(len(checkpoint_files), 0, "No checkpoint file found")

    checkpoint_path = os.path.join(self.test_dir, checkpoint_files[0])
    epoch, loss = trainer2.load_checkpoint(checkpoint_path)

    self.assertIsInstance(epoch, int)
    self.assertIsInstance(loss, float)
class TestEdgeCases(unittest.TestCase):
    """Edge cases and numerical-robustness checks for the DDPM components."""

    def test_different_beta_schedules_consistency(self):
        """Each beta schedule gives finite forward and reverse outputs."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            # subTest names the failing schedule and keeps checking the others.
            with self.subTest(beta_method=method):
                scheduler = VarianceSchedulerDDPM(
                    num_steps=50,
                    beta_method=method,
                    trainable_beta=False
                )
                forward_ddpm = ForwardDDPM(scheduler)
                reverse_ddpm = ReverseDDPM(scheduler)

                x0 = torch.randn(2, 1, 16, 16)
                noise = torch.randn_like(x0)
                # randint's upper bound is exclusive, so t lies in [1, num_steps - 2].
                t = torch.randint(1, scheduler.num_steps - 1, (2,))

                # Forward process
                xt = forward_ddpm(x0, noise, t)

                # Reverse step fed the ground-truth noise instead of a prediction
                xt_minus_1 = reverse_ddpm(xt, noise, t)

                self.assertTrue(torch.all(torch.isfinite(xt)))
                self.assertTrue(torch.all(torch.isfinite(xt_minus_1)))

    def test_extreme_timesteps(self):
        """Forward diffusion at the first (t=0) and last (t=num_steps-1) timesteps."""
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        x0 = torch.randn(2, 1, 16, 16)
        noise = torch.randn_like(x0)

        # t = 0: the output must equal the closed-form marginal
        # sqrt(alpha_bar_0) * x0 + sqrt(1 - alpha_bar_0) * noise.
        t_zero = torch.zeros(2, dtype=torch.long)
        xt_zero = forward_ddpm(x0, noise, t_zero)

        alpha_bar_0 = scheduler.alpha_bars[0]
        expected_xt_zero = torch.sqrt(alpha_bar_0) * x0 + torch.sqrt(1 - alpha_bar_0) * noise
        torch.testing.assert_close(xt_zero, expected_xt_zero, atol=1e-6, rtol=1e-6)

        # Last valid timestep: only shape and finiteness are checked.
        t_max = torch.full((2,), scheduler.num_steps - 1, dtype=torch.long)
        xt_max = forward_ddpm(x0, noise, t_max)

        self.assertEqual(xt_max.shape, x0.shape)
        self.assertTrue(torch.all(torch.isfinite(xt_max)))

    def test_batch_size_consistency(self):
        """Forward diffusion preserves the input shape for several batch sizes."""
        scheduler = VarianceSchedulerDDPM(num_steps=50, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        for batch_size in [1, 4, 8]:
            with self.subTest(batch_size=batch_size):
                x0 = torch.randn(batch_size, 3, 16, 16)
                noise = torch.randn_like(x0)
                t = torch.randint(0, scheduler.num_steps, (batch_size,))

                xt = forward_ddpm(x0, noise, t)

                # Full-shape check (stronger than comparing only the batch dim).
                self.assertEqual(xt.shape, x0.shape)
                self.assertTrue(torch.all(torch.isfinite(xt)))

    def test_memory_efficiency(self):
        """Basic leak check: repeated forward passes should not grow CUDA memory.

        NOTE(review): the tensors below are created on torch's default device
        (CPU unless changed globally), while only CUDA memory is measured. This
        is therefore effectively a smoke test, and the memory assertion is
        skipped entirely when CUDA is unavailable.
        """
        scheduler = VarianceSchedulerDDPM(num_steps=100, trainable_beta=False)
        forward_ddpm = ForwardDDPM(scheduler)

        initial_memory = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0

        for _ in range(10):
            x0 = torch.randn(4, 3, 32, 32)
            noise = torch.randn_like(x0)
            t = torch.randint(0, scheduler.num_steps, (4,))

            xt = forward_ddpm(x0, noise, t)
            del xt, x0, noise, t

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            final_memory = torch.cuda.memory_allocated()
            # Allow for some variance in memory usage (100MB threshold).
            self.assertLess(final_memory - initial_memory, 100 * 1024 * 1024)
def run_performance_benchmarks(num_iters=100, batch_size=32, warmup_iters=5):
    """Time ``ForwardDDPM`` on a fixed random batch and print the results.

    This is an optional benchmark, not part of the main test suite.

    Args:
        num_iters: Number of timed forward calls (default 100).
        batch_size: Images per batch (default 32).
        warmup_iters: Untimed calls made before timing starts (default 5).

    Raises:
        ValueError: If ``num_iters`` is not positive (the average would divide by zero).
    """
    import time

    if num_iters <= 0:
        raise ValueError("num_iters must be positive")

    print("\n" + "=" * 50)
    print("PERFORMANCE BENCHMARKS")
    print("=" * 50)

    scheduler = VarianceSchedulerDDPM(num_steps=1000, trainable_beta=False)
    forward_ddpm = ForwardDDPM(scheduler)

    x0 = torch.randn(batch_size, 3, 64, 64)
    noise = torch.randn_like(x0)
    t = torch.randint(0, scheduler.num_steps, (batch_size,))

    # Inference-only timing: autograd bookkeeping is not needed.
    with torch.no_grad():
        for _ in range(warmup_iters):
            forward_ddpm(x0, noise, t)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        for _ in range(num_iters):
            forward_ddpm(x0, noise, t)
        elapsed = time.perf_counter() - start_time

    avg_time = elapsed / num_iters
    print(f"Forward diffusion avg time: {avg_time * 1000:.2f}ms per batch")
    print(f"Throughput: {batch_size / avg_time:.0f} images/second")
+ class DDPMTestSuite:
1092
+ """Main test suite runner with custom reporting"""
1093
+
1094
+ def __init__(self):
1095
+ self.test_classes = [
1096
+ TestVarianceSchedulerDDPM,
1097
+ TestForwardDDPM,
1098
+ TestReverseDDPM,
1099
+ TestTrainDDPM,
1100
+ TestSampleDDPM,
1101
+ TestIntegration,
1102
+ TestEdgeCases
1103
+ ]
1104
+
1105
+ def run_all_tests(self, verbose=True):
1106
+ """Run all tests with custom reporting"""
1107
+ print("=" * 60)
1108
+ print("DDPM IMPLEMENTATION TEST SUITE")
1109
+ print("=" * 60)
1110
+
1111
+ total_tests = 0
1112
+ total_failures = 0
1113
+ total_errors = 0
1114
+
1115
+ for test_class in self.test_classes:
1116
+ print(f"\nRunning {test_class.__name__}...")
1117
+
1118
+ # Create test suite for this class
1119
+ suite = unittest.TestLoader().loadTestsFromTestCase(test_class)
1120
+
1121
+ # Run tests with custom result handling
1122
+ result = unittest.TextTestRunner(
1123
+ verbosity=2 if verbose else 1,
1124
+ stream=open(os.devnull, 'w') if not verbose else None
1125
+ ).run(suite)
1126
+
1127
+ # Count results
1128
+ tests_run = result.testsRun
1129
+ failures = len(result.failures)
1130
+ errors = len(result.errors)
1131
+
1132
+ total_tests += tests_run
1133
+ total_failures += failures
1134
+ total_errors += errors
1135
+
1136
+ # Print summary for this class
1137
+ status = "PASS" if (failures == 0 and errors == 0) else "FAIL"
1138
+ print(f" {test_class.__name__}: {status}")
1139
+ print(f" Tests: {tests_run}, Failures: {failures}, Errors: {errors}")
1140
+
1141
+ # Print failure details
1142
+ if failures > 0:
1143
+ print(" FAILURES:")
1144
+ for test, traceback in result.failures:
1145
+ print(
1146
+ f" - {test}: {traceback.split('AssertionError:')[-1].strip() if 'AssertionError:' in traceback else 'Unknown failure'}")
1147
+
1148
+ if errors > 0:
1149
+ print(" ERRORS:")
1150
+ for test, traceback in result.errors:
1151
+ error_msg = traceback.split('\n')[-2] if len(traceback.split('\n')) > 1 else "Unknown error"
1152
+ print(f" - {test}: {error_msg}")
1153
+
1154
+ # Final summary
1155
+ print("\n" + "=" * 60)
1156
+ print("FINAL RESULTS")
1157
+ print("=" * 60)
1158
+ print(f"Total Tests: {total_tests}")
1159
+ print(f"Passed: {total_tests - total_failures - total_errors}")
1160
+ print(f"Failed: {total_failures}")
1161
+ print(f"Errors: {total_errors}")
1162
+
1163
+ success_rate = ((total_tests - total_failures - total_errors) / total_tests) * 100 if total_tests > 0 else 0
1164
+ print(f"Success Rate: {success_rate:.1f}%")
1165
+
1166
+ if total_failures == 0 and total_errors == 0:
1167
+ print("\nšŸŽ‰ ALL TESTS PASSED! šŸŽ‰")
1168
+ else:
1169
+ print(f"\nāŒ {total_failures + total_errors} tests failed/errored")
1170
+
1171
+ return total_failures == 0 and total_errors == 0
1172
+
1173
if __name__ == "__main__":
    import sys

    # Suppress warnings for cleaner output
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)

    # Run main test suite
    all_passed = DDPMTestSuite().run_all_tests(verbose=True)

    # Only prompt when stdin is an interactive terminal. Under CI or piped
    # input, input() would block or raise EOFError after the suite has run.
    run_benchmarks = False
    if sys.stdin is not None and sys.stdin.isatty():
        try:
            answer = input("\nRun performance benchmarks? (y/n): ")
        except EOFError:
            answer = ""
        run_benchmarks = answer.strip().lower().startswith('y')

    if run_benchmarks:
        run_performance_benchmarks()

    # sys.exit works even when the site module (which provides exit()) is not loaded.
    sys.exit(0 if all_passed else 1)