TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
File without changes
@@ -0,0 +1,551 @@
1
+ """
2
+ Comprehensive Test Suite for DDIM Implementation
3
+
4
+ This test suite validates the core components of the DDIM (Denoising Diffusion Implicit Models)
5
+ implementation including forward/reverse diffusion, variance scheduling, training, and sampling.
6
+ """
7
+
8
+ import pytest
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import tempfile
13
+ import os
14
+ import shutil
15
+ from typing import Tuple, List
16
+ from torch.utils.data import DataLoader, TensorDataset
17
+ from transformers import BertTokenizer
18
+ from torchdiff.ddim import (
19
+ ForwardDDIM,
20
+ ReverseDDIM,
21
+ VarianceSchedulerDDIM,
22
+ TrainDDIM,
23
+ SampleDDIM
24
+ )
25
+
26
+
27
+
28
+ # Mock implementations for testing
29
class MockNoisePredictor(nn.Module):
    """Minimal stand-in for a noise-prediction network.

    The prediction is one same-padded 3x3 convolution of the input, so the
    output always has the input's shape. The timestep is pushed through a
    linear embedding, but that embedding is not mixed into the prediction.

    Args:
        in_channels: Channel count of both the input and the output.
        time_embed_dim: Output width of the timestep embedding layer.
    """

    def __init__(self, in_channels=1, time_embed_dim=32):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        self.time_embed = nn.Linear(1, time_embed_dim)

    def forward(self, x, t, y_encoded=None, *args):
        """Return ``self.conv(x)``.

        Args:
            x: Input tensor fed to the convolution.
            t: Timestep tensor; cast to float and given a trailing axis
                before being embedded.
            y_encoded: Ignored.
            *args: Ignored.
        """
        # The embedding is computed but deliberately discarded: it does not
        # influence the returned prediction.
        self.time_embed(t.float().unsqueeze(1))
        return self.conv(x)
43
+
44
+
45
class MockConditionalModel(nn.Module):
    """Minimal stand-in for a conditioning encoder.

    Each sequence is represented by the embedding of its first token id only.

    Args:
        embed_dim: Width of the returned embedding vectors.
        vocab_size: Number of rows in the embedding table. Token ids looked
            up in ``forward`` must be smaller than this value.
    """

    def __init__(self, embed_dim=32, vocab_size=1000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)

    def forward(self, input_ids, attention_mask=None):
        """Embed column 0 of ``input_ids``; ``attention_mask`` is ignored."""
        first_token_ids = input_ids[:, 0]
        return self.embed(first_token_ids)
54
+
55
+
56
class MockMetrics:
    """Stand-in metrics object that reports fixed values."""

    def __init__(self):
        # All three switches are enabled on construction.
        self.fid, self.metrics, self.lpips = True, True, True

    def forward(self, x_real, x_fake):
        """Return constant ``(fid, mse, psnr, ssim, lpips)``; inputs are ignored."""
        fid, mse, psnr, ssim, lpips = 50.0, 0.1, 25.0, 0.8, 0.05
        return fid, mse, psnr, ssim, lpips
67
+
68
+
69
class TestVarianceSchedulerDDIM:
    """Test cases for VarianceSchedulerDDIM"""

    def test_initialization_valid_params(self):
        """Constructor arguments are exposed unchanged as attributes."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=100,
            tau_num_steps=10,
            beta_start=1e-4,
            beta_end=0.02
        )
        assert scheduler.num_steps == 100
        assert scheduler.tau_num_steps == 10
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02

    def test_initialization_invalid_params(self):
        """Invalid constructor arguments raise ValueError."""
        with pytest.raises(ValueError):
            # Invalid beta range
            VarianceSchedulerDDIM(beta_start=0.02, beta_end=1e-4)

        with pytest.raises(ValueError):
            # Invalid num_steps
            VarianceSchedulerDDIM(num_steps=0)

    def test_beta_schedule_methods(self):
        """Every beta method yields num_steps values within [beta_start, beta_end]."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            scheduler = VarianceSchedulerDDIM(
                num_steps=100,
                beta_method=method
            )
            betas = scheduler.betas
            assert len(betas) == 100
            assert torch.all(betas >= scheduler.beta_start)
            assert torch.all(betas <= scheduler.beta_end)

    def test_trainable_beta(self):
        """With trainable_beta=True, beta_raw is a Parameter and betas stay in range."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=100,
            trainable_beta=True
        )

        # Check that beta_raw is a parameter
        assert hasattr(scheduler, 'beta_raw')
        assert isinstance(scheduler.beta_raw, nn.Parameter)

        # Check that betas are in valid range
        betas = scheduler.betas
        assert torch.all(betas >= scheduler.beta_start)
        assert torch.all(betas <= scheduler.beta_end)

    def test_tau_schedule(self):
        """The subsampled (tau) schedule has tau_num_steps entries and alpha = 1 - beta."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=1000,
            tau_num_steps=100
        )

        # Only the first three of the five returned sequences are checked here.
        tau_betas, tau_alphas, tau_alpha_cumprod, _, _ = scheduler.get_tau_schedule()

        assert len(tau_betas) == 100
        assert len(tau_alphas) == 100
        assert len(tau_alpha_cumprod) == 100
        # Compare with a tolerance: exact float equality fails as soon as the
        # two sides are computed along different paths or in different dtypes.
        assert torch.allclose(tau_alphas, 1 - tau_betas)
138
+
139
+
140
class TestForwardDDIM:
    """Checks for the ForwardDDIM noising step."""

    def test_forward_process(self):
        """The noised batch keeps the input shape and differs from the input."""
        scheduler = VarianceSchedulerDDIM(num_steps=100)
        forward_ddim = ForwardDDIM(scheduler)

        shape = (4, 1, 28, 28)
        clean = torch.randn(*shape)
        eps = torch.randn_like(clean)
        steps = torch.randint(0, scheduler.num_steps, (shape[0],))

        noised = forward_ddim(clean, eps, steps)

        assert noised.shape == clean.shape
        # Adding noise must change the tensor.
        assert not torch.equal(noised, clean)

    def test_invalid_time_steps(self):
        """Time steps at or beyond num_steps raise ValueError."""
        scheduler = VarianceSchedulerDDIM(num_steps=100)
        forward_ddim = ForwardDDIM(scheduler)

        clean = torch.randn(2, 1, 28, 28)
        eps = torch.randn_like(clean)
        # Both entries lie outside the valid 0..99 range.
        bad_steps = torch.tensor([100, 150])

        with pytest.raises(ValueError):
            forward_ddim(clean, eps, bad_steps)
170
+
171
+
172
class TestReverseDDIM:
    """Checks for the ReverseDDIM denoising step."""

    def test_reverse_process(self):
        """One reverse step returns two tensors shaped like the input."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=1000,
            tau_num_steps=100,
            eta=0.0  # Deterministic
        )
        reverse_ddim = ReverseDDIM(scheduler)

        shape = (4, 1, 28, 28)
        noisy = torch.randn(*shape)
        eps_hat = torch.randn_like(noisy)
        # Start at 1 so that the previous step (t - 1) is never negative.
        steps = torch.randint(1, scheduler.tau_num_steps, (shape[0],))
        earlier_steps = steps - 1

        stepped, x0_estimate = reverse_ddim(noisy, eps_hat, steps, earlier_steps)

        assert stepped.shape == noisy.shape
        assert x0_estimate.shape == noisy.shape

    def test_stochastic_sampling(self):
        """With eta > 0, two identical calls produce different outputs."""
        scheduler = VarianceSchedulerDDIM(
            num_steps=1000,
            tau_num_steps=100,
            eta=0.5  # Stochastic
        )
        reverse_ddim = ReverseDDIM(scheduler)

        noisy = torch.randn(2, 1, 28, 28)
        eps_hat = torch.randn_like(noisy)
        steps = torch.tensor([50, 60])
        earlier_steps = torch.tensor([49, 59])

        # Same inputs twice; the outputs are expected to disagree.
        first, _ = reverse_ddim(noisy, eps_hat, steps, earlier_steps)
        second, _ = reverse_ddim(noisy, eps_hat, steps, earlier_steps)

        assert not torch.allclose(first, second, atol=1e-6)
214
+
215
+
216
class TestTrainDDIM:
    """Checks for the TrainDDIM training driver."""

    def setup_training_components(self):
        """Build the loaders, mock models, DDIM modules, optimizer, loss and metrics.

        Returns:
            dict keyed by 'train_loader', 'val_loader', 'noise_predictor',
            'conditional_model', 'forward_ddim', 'reverse_ddim', 'optimizer',
            'loss_fn' and 'metrics'.
        """
        # Random images with float-cast integer labels; both loaders share
        # the same dataset.
        images = torch.randn(100, 1, 28, 28)
        labels = torch.randint(0, 10, (100,)).float()
        dataset = TensorDataset(images, labels)
        train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
        val_loader = DataLoader(dataset, batch_size=16, shuffle=False)

        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()

        scheduler = VarianceSchedulerDDIM(num_steps=100, tau_num_steps=10)
        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # A single optimizer covers the parameters of both mock models.
        trainable = list(noise_predictor.parameters()) + list(conditional_model.parameters())
        optimizer = torch.optim.Adam(trainable, lr=1e-3)

        return {
            'train_loader': train_loader,
            'val_loader': val_loader,
            'noise_predictor': noise_predictor,
            'conditional_model': conditional_model,
            'forward_ddim': forward_ddim,
            'reverse_ddim': reverse_ddim,
            'optimizer': optimizer,
            'loss_fn': nn.MSELoss(),
            'metrics': MockMetrics()
        }

    def _make_trainer(self, parts, store_path, **extra):
        """Construct a CPU, non-DDP TrainDDIM from ``parts`` plus per-test extras."""
        return TrainDDIM(
            noise_predictor=parts['noise_predictor'],
            forward_diffusion=parts['forward_ddim'],
            reverse_diffusion=parts['reverse_ddim'],
            data_loader=parts['train_loader'],
            optimizer=parts['optimizer'],
            objective=parts['loss_fn'],
            device='cpu',
            store_path=store_path,
            use_ddp=False,
            **extra
        )

    def test_trainer_initialization(self):
        """The trainer stores max_epochs and normalises the device to torch.device."""
        parts = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(
                parts,
                temp_dir,
                val_loader=parts['val_loader'],
                max_epochs=2,
                conditional_model=parts['conditional_model'],
                metrics_=parts['metrics']
            )

            assert trainer.max_epochs == 2
            assert trainer.device == torch.device('cpu')

    def test_training_loop(self):
        """Calling the trainer returns one loss per epoch and a float best loss."""
        parts = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(
                parts,
                temp_dir,
                max_epochs=2,
                log_frequency=1
            )

            # Run training
            train_losses, best_val_loss = trainer()

            assert len(train_losses) == 2
            assert isinstance(best_val_loss, float)

    def test_checkpoint_save_load(self):
        """A manually saved checkpoint loads back with the same epoch and loss."""
        parts = self.setup_training_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(parts, temp_dir, max_epochs=1)

            # Save checkpoint manually
            trainer._save_checkpoint(1, 0.5)

            # At least one .pth file must now exist in the store directory.
            saved = [name for name in os.listdir(temp_dir) if name.endswith('.pth')]
            assert len(saved) > 0

            # Load the first checkpoint found and compare the round-trip.
            epoch, loss = trainer.load_checkpoint(os.path.join(temp_dir, saved[0]))

            assert epoch == 1
            assert loss == 0.5
337
+
338
+
339
class TestSampleDDIM:
    """Checks for the SampleDDIM sampler."""

    def setup_sampling_components(self):
        """Build the reverse process and the two mock models used for sampling.

        Returns:
            dict keyed by 'reverse_ddim', 'noise_predictor' and
            'conditional_model'.
        """
        scheduler = VarianceSchedulerDDIM(num_steps=100, tau_num_steps=10)
        reverse_ddim = ReverseDDIM(scheduler)
        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()
        return {
            'reverse_ddim': reverse_ddim,
            'noise_predictor': noise_predictor,
            'conditional_model': conditional_model
        }

    def _make_sampler(self, parts, **extra):
        """Construct a CPU SampleDDIM for 28x28 images plus per-test extras."""
        return SampleDDIM(
            reverse_diffusion=parts['reverse_ddim'],
            noise_predictor=parts['noise_predictor'],
            image_shape=(28, 28),
            device='cpu',
            **extra
        )

    def test_sampler_initialization(self):
        """Constructor arguments are reflected on the sampler's attributes."""
        parts = self.setup_sampling_components()

        sampler = self._make_sampler(parts, batch_size=4)

        assert sampler.image_shape == (28, 28)
        assert sampler.batch_size == 4
        assert sampler.device == torch.device('cpu')

    def test_unconditional_sampling(self):
        """Sampling without conditions returns a (2, 1, 28, 28) batch in [0, 1]."""
        parts = self.setup_sampling_components()

        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(
                parts,
                batch_size=2,
                in_channels=1,  # Match the mock predictor
                image_output_range=(-1, 1)
            )

            # Generate samples
            generated_images = sampler(
                conditions=None,
                save_images=True,
                save_path=temp_dir
            )

            assert generated_images.shape == (2, 1, 28, 28)  # 1 channel to match predictor
            # NOTE(review): image_output_range is (-1, 1) but values are
            # asserted to lie in [0, 1]; presumably the sampler rescales its
            # output. Confirm against SampleDDIM.
            assert torch.all(generated_images >= 0) and torch.all(generated_images <= 1)

    def test_conditional_sampling(self):
        """Sampling with two string prompts returns a (2, 1, 28, 28) batch."""
        parts = self.setup_sampling_components()

        # The temporary directory is created but not passed to the sampler.
        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(
                parts,
                conditional_model=parts['conditional_model'],
                batch_size=2,
                in_channels=1  # Match the mock predictor
            )

            # Test with string prompts
            prompts = ["cat", "dog"]
            generated_images = sampler(
                conditions=prompts,
                save_images=False
            )

            assert generated_images.shape == (2, 1, 28, 28)  # 1 channel to match predictor

    def test_tokenization(self):
        """tokenize() returns one row per prompt, max_token_length columns wide."""
        parts = self.setup_sampling_components()

        sampler = self._make_sampler(
            parts,
            conditional_model=parts['conditional_model']
        )

        input_ids, attention_mask = sampler.tokenize(["a cat", "a dog"])

        assert input_ids.shape[0] == 2  # Batch size
        assert attention_mask.shape[0] == 2
        assert input_ids.shape[1] == sampler.max_token_length
438
+
439
+
440
class TestIntegrationDDIM:
    """Integration tests for full DDIM pipeline"""

    def test_end_to_end_pipeline(self):
        """Train for one epoch, then sample two conditioned 16x16 images."""
        # Setup data
        x = torch.randn(20, 1, 16, 16)  # Smaller for faster testing
        y = torch.randint(0, 5, (20,)).float()
        dataset = TensorDataset(x, y)
        train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

        # Setup models
        noise_predictor = MockNoisePredictor()
        conditional_model = MockConditionalModel()

        # Setup DDIM components
        scheduler = VarianceSchedulerDDIM(num_steps=50, tau_num_steps=5)
        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # Setup training
        optimizer = torch.optim.Adam(
            list(noise_predictor.parameters()) + list(conditional_model.parameters()),
            lr=1e-3
        )
        loss_fn = nn.MSELoss()

        with tempfile.TemporaryDirectory() as temp_dir:
            # Train model
            trainer = TrainDDIM(
                noise_predictor=noise_predictor,
                forward_diffusion=forward_ddim,
                reverse_diffusion=reverse_ddim,
                data_loader=train_loader,
                optimizer=optimizer,
                objective=loss_fn,
                max_epochs=1,
                device='cpu',
                conditional_model=conditional_model,
                store_path=temp_dir,
                use_ddp=False
            )

            train_losses, _ = trainer()
            assert len(train_losses) == 1

            # Test sampling
            sampler = SampleDDIM(
                reverse_diffusion=reverse_ddim,
                noise_predictor=noise_predictor,
                conditional_model=conditional_model,
                image_shape=(16, 16),
                batch_size=2,
                in_channels=1,
                device='cpu'
            )

            conditions = ["test1", "test2"]
            generated_images = sampler(conditions=conditions, save_images=False)

            assert generated_images.shape == (2, 1, 16, 16)

    def test_forward_reverse_consistency(self):
        """Test that a forward step followed by a reverse step preserves shapes.

        Only tensor shapes are asserted; numerical agreement between the two
        processes is not checked here.
        """
        scheduler = VarianceSchedulerDDIM(
            num_steps=100,
            tau_num_steps=10,
            eta=0.0  # Deterministic
        )

        forward_ddim = ForwardDDIM(scheduler)
        reverse_ddim = ReverseDDIM(scheduler)

        # Original image
        x0 = torch.randn(1, 1, 16, 16)

        # Apply forward process using full timestep range
        noise = torch.randn_like(x0)
        t_full = torch.tensor([50])  # Use full timestep range for forward
        xt = forward_ddim(x0, noise, t_full)

        # Apply reverse process using tau (subsampled) timestep range. The
        # noise that was actually added stands in for a model prediction, so
        # no predictor network is needed.
        t_tau = torch.tensor([5])  # Middle of tau range (0-9)
        prev_t_tau = torch.tensor([4])  # Previous tau timestep
        xt_prev, x0_pred = reverse_ddim(xt, noise, t_tau, prev_t_tau)

        # Check that shapes are consistent
        assert xt_prev.shape == x0.shape
        assert x0_pred.shape == x0.shape
530
+
531
+
532
+ # Pytest configuration and runner
533
if __name__ == "__main__":
    # pytest treats positional arguments as file paths / node ids, so a bare
    # class name is reported as "file or directory not found". Each test class
    # is therefore selected through its node id: <this file>::<ClassName>.
    test_groups = [
        ("VarianceSchedulerDDIM", "TestVarianceSchedulerDDIM"),
        ("ForwardDDIM", "TestForwardDDIM"),
        ("ReverseDDIM", "TestReverseDDIM"),
        ("TrainDDIM", "TestTrainDDIM"),
        ("SampleDDIM", "TestSampleDDIM"),
        ("Integration", "TestIntegrationDDIM"),
    ]

    for position, (label, class_name) in enumerate(test_groups):
        # Every banner after the first is preceded by a blank line.
        separator = "\n" if position else ""
        print(f"{separator}Running {label} tests...")
        pytest.main(["-v", f"{__file__}::{class_name}"])