TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,742 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn as nn
4
+ import tempfile
5
+ import shutil
6
+ import os
7
+ from unittest.mock import Mock, patch
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+ from torchvision import transforms
10
+ import numpy as np
11
+
12
+ # Import the LDM components under test from torchdiff/ldm.py
13
+ from torchdiff.ldm import (
14
+ AutoencoderLDM, TrainAE, TrainLDM, SampleLDM,
15
+ VectorQuantizer, DownBlock, UpBlock, Conv3,
16
+ DownSampling, UpSampling, Attention
17
+ )
18
+ from torchdiff.sde import ForwardSDE, ReverseSDE, VarianceSchedulerSDE
19
+
20
+
21
+
22
+ # Mock utility classes that would normally come from torchdiff/utils.py
23
class MockTextEncoder(nn.Module):
    """Parameter-free stand-in for a text encoder.

    The mock ignores the token content and returns random embeddings of
    shape ``(batch_size, output_dim)``.

    Args:
        output_dim: Size of the embedding returned for every batch item.
    """

    def __init__(self, output_dim=32):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, input_ids, attention_mask=None):
        """Return a random ``(batch_size, output_dim)`` embedding.

        Args:
            input_ids: Tensor whose first dimension is the batch size.
            attention_mask: Accepted for interface compatibility; unused.

        Returns:
            A float tensor of shape ``(input_ids.shape[0], output_dim)``
            allocated on the same device as ``input_ids``.
        """
        batch_size = input_ids.shape[0]
        # Follow the input's device so the mock also works off-CPU.
        return torch.randn(batch_size, self.output_dim, device=input_ids.device)
31
+
32
+
33
class MockNoisePredictor(nn.Module):
    """Minimal noise predictor: one shape-preserving 3x3 convolution.

    Args:
        in_channels: Channel count of both the input and the output.
    """

    def __init__(self, in_channels=2):
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, t, y=None, context=None):
        """Apply the convolution to ``x``; ``t``, ``y`` and ``context`` are ignored."""
        prediction = self.conv(x)
        return prediction
41
+
42
+
43
class MockMetrics:
    """Stand-in metrics object that returns fixed scores.

    Args:
        device: Stored on the instance; not used by ``forward``.
        fid: Whether ``forward`` reports an FID value.
        metrics: Whether ``forward`` reports MSE, PSNR and SSIM.
        lpips_: Whether ``forward`` reports an LPIPS value.
    """

    def __init__(self, device="cpu", fid=True, metrics=True, lpips_=True):
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_

    def forward(self, x_real, x_fake):
        """Return ``(fid, mse, psnr, ssim, lpips)`` as constants.

        The inputs are ignored. A disabled FID is reported as
        ``float('inf')``; every other disabled metric is ``None``.
        """
        fid_score = torch.tensor(10.0) if self.fid else float('inf')

        if self.metrics:
            mse = torch.tensor(0.1)
            psnr = torch.tensor(25.0)
            ssim = torch.tensor(0.8)
        else:
            mse = psnr = ssim = None

        lpips_score = torch.tensor(0.2) if self.lpips else None
        return (fid_score, mse, psnr, ssim, lpips_score)
58
+
59
+
60
class TestAutoencoderLDM:
    """Test suite for the AutoencoderLDM component."""

    @pytest.fixture
    def autoencoder_config(self):
        """Return the keyword arguments shared by the autoencoder tests."""
        return {
            'in_channels': 3,
            'down_channels': [8, 16],
            'up_channels': [16, 8],
            'out_channels': 3,
            'dropout_rate': 0.1,
            'num_heads': 1,
            'num_groups': 8,
            'num_layers_per_block': 2,
            'total_down_sampling_factor': 2,
            'latent_channels': 4,
            'num_embeddings': 32,
            'use_vq': False,
            'beta': 1.0,
        }

    @pytest.fixture
    def sample_data(self):
        """Return a random batch of two 3x32x32 images."""
        return torch.randn(2, 3, 32, 32)

    def test_autoencoder_initialization(self, autoencoder_config):
        """Check construction in both the KL and the VQ configuration."""
        # KL-divergence mode
        model_kl = AutoencoderLDM(**autoencoder_config)
        assert not model_kl.use_vq
        assert model_kl.beta == 1.0

        # VQ mode
        config_vq = autoencoder_config.copy()
        config_vq['use_vq'] = True
        model_vq = AutoencoderLDM(**config_vq)
        assert model_vq.use_vq
        assert hasattr(model_vq, 'vq_layer')

    def test_autoencoder_forward_pass(self, autoencoder_config, sample_data):
        """Check the shapes and types returned by a forward pass."""
        model = AutoencoderLDM(**autoencoder_config)
        model.eval()

        with torch.no_grad():
            x_hat, total_loss, reg_loss, z = model(sample_data)

        # Reconstruction keeps the input shape
        assert x_hat.shape == sample_data.shape
        assert isinstance(total_loss, float)
        assert isinstance(reg_loss, (float, torch.Tensor))

        # Latent shape, based on the config
        expected_latent_shape = (2, 4, 16, 16)
        assert z.shape == expected_latent_shape

    def test_encode_decode_consistency(self, autoencoder_config, sample_data):
        """Check that encode followed by decode restores the input shape."""
        model = AutoencoderLDM(**autoencoder_config)
        model.eval()

        with torch.no_grad():
            z, _reg_loss = model.encode(sample_data)
            x_reconstructed = model.decode(z)

        assert x_reconstructed.shape == sample_data.shape
        assert z.shape[0] == sample_data.shape[0]  # Batch size preserved

    def test_vq_functionality(self, autoencoder_config, sample_data):
        """Check that encoding in VQ mode yields a tensor loss."""
        config_vq = autoencoder_config.copy()
        config_vq['use_vq'] = True
        model = AutoencoderLDM(**config_vq)
        model.eval()

        with torch.no_grad():
            z, vq_loss = model.encode(sample_data)

        assert isinstance(vq_loss, torch.Tensor)
        # Computed under no_grad, so the loss must be detached from autograd.
        assert not vq_loss.requires_grad

    def test_reparameterization_trick(self, autoencoder_config):
        """Check the reparameterization trick used in KL mode."""
        model = AutoencoderLDM(**autoencoder_config)

        mu = torch.randn(2, 8, 16, 16)
        logvar = torch.randn(2, 8, 16, 16)

        z = model.reparameterize(mu, logvar)
        assert z.shape == mu.shape

        # With a very negative log-variance the sample should be close to
        # the mean.
        logvar_tiny = torch.full_like(logvar, -20)
        z_det = model.reparameterize(mu, logvar_tiny)
        torch.testing.assert_close(z_det, mu, atol=1e-3, rtol=1e-3)
157
+
158
+
159
class TestVectorQuantizer:
    """Tests covering the VectorQuantizer component."""

    def test_vq_initialization(self):
        """The constructor arguments are stored and size the codebook."""
        quantizer = VectorQuantizer(num_embeddings=64, embedding_dim=32)

        assert quantizer.num_embeddings == 64
        assert quantizer.embedding_dim == 32
        assert quantizer.embedding.weight.shape == (64, 32)

    def test_vq_forward_pass(self):
        """A forward pass keeps the input shape and returns a differentiable loss."""
        quantizer = VectorQuantizer(num_embeddings=64, embedding_dim=16)
        features = torch.randn(2, 16, 8, 8)

        quantized, vq_loss = quantizer(features)

        assert quantized.shape == features.shape
        assert isinstance(vq_loss, torch.Tensor)
        assert vq_loss.requires_grad
180
+
181
+
182
class TestConvolutionalBlocks:
    """Shape checks for the convolutional building blocks."""

    def test_conv3_block(self):
        """Conv3 changes the channel count and keeps the spatial size."""
        block = Conv3(in_channels=16, out_channels=32, dropout_rate=0.1)
        inputs = torch.randn(2, 16, 32, 32)

        result = block(inputs)
        assert result.shape == (2, 32, 32, 32)

    def test_down_block(self):
        """DownBlock halves the spatial size for a factor of 2."""
        block = DownBlock(
            in_channels=16,
            out_channels=32,
            num_layers=2,
            down_sampling_factor=2,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 16, 32, 32)

        result = block(inputs)
        assert result.shape == (2, 32, 16, 16)  # Downsampled by factor of 2

    def test_up_block(self):
        """UpBlock doubles the spatial size for a factor of 2."""
        block = UpBlock(
            in_channels=32,
            out_channels=16,
            num_layers=2,
            up_sampling_factor=2,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 32, 16, 16)

        result = block(inputs)
        assert result.shape == (2, 16, 32, 32)  # Upsampled by factor of 2

    def test_attention_block(self):
        """Attention returns a tensor with the input's shape."""
        block = Attention(
            num_channels=32,
            num_heads=4,
            num_groups=8,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 32, 16, 16)

        result = block(inputs)
        assert result.shape == inputs.shape
227
+
228
+
229
class TestSamplingLayers:
    """Shape checks for the standalone sampling layers."""

    def test_down_sampling(self):
        """DownSampling with factor 2 maps 16x32x32 to 32x16x16."""
        layer = DownSampling(
            in_channels=16,
            out_channels=32,
            down_sampling_factor=2,
        )
        inputs = torch.randn(2, 16, 32, 32)

        result = layer(inputs)
        assert result.shape == (2, 32, 16, 16)

    def test_up_sampling(self):
        """UpSampling with factor 2 maps 32x16x16 to 16x32x32."""
        layer = UpSampling(
            in_channels=32,
            out_channels=16,
            up_sampling_factor=2,
        )
        inputs = torch.randn(2, 32, 16, 16)

        result = layer(inputs)
        assert result.shape == (2, 16, 32, 32)
253
+
254
+
255
class TestVarianceSchedulerSDE:
    """Test suite for the SDE variance scheduler."""

    def test_scheduler_initialization(self):
        """Check that constructor arguments are stored on the scheduler."""
        scheduler = VarianceSchedulerSDE(
            num_steps=100, beta_start=1e-4, beta_end=0.02,
            trainable_beta=False
        )

        assert scheduler.num_steps == 100
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02
        assert not scheduler.trainable_beta

    # Parametrized so that every schedule is reported separately; a plain loop
    # would stop at the first failing method and hide the remaining ones.
    @pytest.mark.parametrize(
        "method", ["linear", "sigmoid", "quadratic", "constant"]
    )
    def test_beta_schedules(self, method):
        """Check that each beta schedule has the right length and range."""
        scheduler = VarianceSchedulerSDE(
            num_steps=100, beta_start=1e-4, beta_end=0.02,
            beta_method=method, trainable_beta=False
        )

        betas = scheduler.betas
        assert betas.shape == (100,)
        assert torch.all(betas >= 1e-4)
        assert torch.all(betas <= 0.02)

    def test_trainable_beta(self):
        """Check the trainable-beta configuration."""
        scheduler = VarianceSchedulerSDE(
            num_steps=50, trainable_beta=True
        )

        # beta_raw must be a learnable parameter
        assert hasattr(scheduler, 'beta_raw')
        assert isinstance(scheduler.beta_raw, nn.Parameter)

        # The derived betas must stay inside the configured range
        betas = scheduler.betas
        assert torch.all(betas >= scheduler.beta_start)
        assert torch.all(betas <= scheduler.beta_end)

    @pytest.mark.parametrize("method", ["ve", "vp", "sub-vp"])
    def test_variance_computation(self, method):
        """Check that each SDE method yields one non-negative variance per step."""
        scheduler = VarianceSchedulerSDE(num_steps=100, trainable_beta=False)
        time_steps = torch.tensor([10, 20, 30])

        variance = scheduler.get_variance(time_steps, method)
        assert variance.shape == time_steps.shape
        assert torch.all(variance >= 0)
310
+
311
+
312
class TestSDEProcesses:
    """Test suite for the SDE forward and reverse processes."""

    @pytest.fixture
    def sde_setup(self):
        """Return the variance scheduler shared by the SDE process tests."""
        scheduler = VarianceSchedulerSDE(
            num_steps=100, beta_start=1e-4, beta_end=0.02
        )
        return scheduler

    # Parametrized so that a failure names the offending method and the other
    # methods still run.
    @pytest.mark.parametrize("method", ["ve", "vp", "sub-vp", "ode"])
    def test_forward_sde_methods(self, sde_setup, method):
        """Check that ForwardSDE keeps the shape and produces no NaNs."""
        x0 = torch.randn(2, 3, 32, 32)
        noise = torch.randn_like(x0)
        time_steps = torch.tensor([10, 20])

        forward_sde = ForwardSDE(sde_setup, method)
        xt = forward_sde(x0, noise, time_steps)

        assert xt.shape == x0.shape
        assert not torch.isnan(xt).any()

    @pytest.mark.parametrize("method", ["ve", "vp", "sub-vp", "ode"])
    def test_reverse_sde_methods(self, sde_setup, method):
        """Check that ReverseSDE keeps the shape and produces no NaNs."""
        xt = torch.randn(2, 3, 32, 32)
        noise = torch.randn_like(xt)
        predicted_noise = torch.randn_like(xt)
        time_steps = torch.tensor([10, 20])

        reverse_sde = ReverseSDE(sde_setup, method)
        x_prev = reverse_sde(xt, noise, predicted_noise, time_steps)

        assert x_prev.shape == xt.shape
        assert not torch.isnan(x_prev).any()

    def test_ode_method_without_noise(self, sde_setup):
        """Check that the reverse ODE step accepts ``noise=None``."""
        forward_sde = ForwardSDE(sde_setup, "ode")
        reverse_sde = ReverseSDE(sde_setup, "ode")

        x0 = torch.randn(2, 3, 32, 32)
        xt = torch.randn_like(x0)
        predicted_noise = torch.randn_like(x0)
        time_steps = torch.tensor([10, 20])

        # Forward step with the ODE method
        xt_forward = forward_sde(x0, torch.randn_like(x0), time_steps)

        # Reverse step with the ODE method (noise can be None)
        x_prev = reverse_sde(xt, None, predicted_noise, time_steps)

        assert not torch.isnan(xt_forward).any()
        assert not torch.isnan(x_prev).any()
370
+
371
+
372
class TestTrainAE:
    """Tests covering the TrainAE autoencoder trainer."""

    @pytest.fixture
    def training_setup(self):
        """Build the model, optimizer, loaders and metrics used by the tests."""
        # Small random dataset; the same data backs both loaders.
        images = torch.randn(20, 3, 32, 32)
        targets = torch.randint(0, 10, (20,))
        dataset = TensorDataset(images, targets)
        train_loader = DataLoader(dataset, batch_size=4)
        val_loader = DataLoader(dataset, batch_size=4)

        # Model and optimizer
        model = AutoencoderLDM(
            in_channels=3,
            down_channels=[8, 16],
            up_channels=[16, 8],
            out_channels=3,
            dropout_rate=0.1,
            latent_channels=4,
            num_heads=1,
            num_groups=8,
            num_layers_per_block=2,
            total_down_sampling_factor=2,
            num_embeddings=32,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        metrics = MockMetrics()

        return dict(
            model=model,
            optimizer=optimizer,
            train_loader=train_loader,
            val_loader=val_loader,
            metrics=metrics,
        )

    def test_train_ae_initialization(self, training_setup):
        """The trainer stores the epoch budget and resolves the device."""
        cfg = training_setup
        trainer = TrainAE(
            model=cfg['model'],
            optimizer=cfg['optimizer'],
            data_loader=cfg['train_loader'],
            val_loader=cfg['val_loader'],
            max_epochs=2,
            metrics_=cfg['metrics'],
            device='cpu',
        )

        assert trainer.max_epochs == 2
        assert trainer.device.type == 'cpu'

    def test_train_ae_forward(self, training_setup):
        """Running the trainer yields per-epoch losses and a finite best loss."""
        cfg = training_setup
        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = TrainAE(
                model=cfg['model'],
                optimizer=cfg['optimizer'],
                data_loader=cfg['train_loader'],
                val_loader=cfg['val_loader'],
                max_epochs=2,
                metrics_=cfg['metrics'],
                device='cpu',
                store_path=temp_dir,
                val_frequency=1,
                log_frequency=1,
            )

            train_losses, best_val_loss = trainer()

            assert len(train_losses) <= 2
            assert isinstance(best_val_loss, float)
            assert best_val_loss < float('inf')
440
+
441
+
442
class TestTrainLDM:
    """Test suite for the LDM trainer."""

    @pytest.fixture
    def ldm_training_setup(self):
        """Build the models, SDE processes, optimizer and loaders for the tests."""
        # Dataset of images paired with text labels. TensorDataset only
        # accepts tensors (it calls ``.size(0)`` on every argument), so a list
        # of strings would raise AttributeError. A plain list of
        # ``(image, label)`` pairs is a valid map-style dataset, and the
        # default collate function keeps the strings as a sequence.
        data = torch.randn(16, 3, 32, 32)
        labels = ['class_' + str(i % 4) for i in range(16)]
        dataset = list(zip(data, labels))
        train_loader = DataLoader(dataset, batch_size=4)
        val_loader = DataLoader(dataset, batch_size=4)

        # Components
        compressor = AutoencoderLDM(
            in_channels=3, down_channels=[8, 16], up_channels=[16, 8],
            out_channels=3, dropout_rate=0.1, latent_channels=4,
            num_heads=1, num_groups=8, num_layers_per_block=2,
            total_down_sampling_factor=2, num_embeddings=32
        )

        noise_predictor = MockNoisePredictor(in_channels=4)
        text_encoder = MockTextEncoder(output_dim=32)

        scheduler = VarianceSchedulerSDE(num_steps=50, trainable_beta=False)
        forward_sde = ForwardSDE(scheduler, "ode")
        reverse_sde = ReverseSDE(scheduler, "ode")

        optimizer = torch.optim.Adam([
            *noise_predictor.parameters(),
            *text_encoder.parameters()
        ], lr=1e-3)

        return {
            'compressor': compressor,
            'noise_predictor': noise_predictor,
            'text_encoder': text_encoder,
            'forward_sde': forward_sde,
            'reverse_sde': reverse_sde,
            'optimizer': optimizer,
            'train_loader': train_loader,
            'val_loader': val_loader
        }

    def test_train_ldm_initialization(self, ldm_training_setup):
        """Check that TrainLDM stores its configuration."""
        setup = ldm_training_setup

        trainer = TrainLDM(
            diffusion_model="sde",
            forward_diffusion=setup['forward_sde'],
            reverse_diffusion=setup['reverse_sde'],
            noise_predictor=setup['noise_predictor'],
            compressor_model=setup['compressor'],
            optimizer=setup['optimizer'],
            objective=nn.MSELoss(),
            data_loader=setup['train_loader'],
            val_loader=setup['val_loader'],
            conditional_model=setup['text_encoder'],
            max_epochs=2,
            device='cpu'
        )

        assert trainer.diffusion_model == "sde"
        assert trainer.max_epochs == 2

    def test_train_ldm_forward(self, ldm_training_setup):
        """Check that the training loop runs and reports its losses."""
        setup = ldm_training_setup

        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = TrainLDM(
                diffusion_model="sde",
                forward_diffusion=setup['forward_sde'],
                reverse_diffusion=setup['reverse_sde'],
                noise_predictor=setup['noise_predictor'],
                compressor_model=setup['compressor'],
                optimizer=setup['optimizer'],
                objective=nn.MSELoss(),
                data_loader=setup['train_loader'],
                val_loader=setup['val_loader'],
                conditional_model=setup['text_encoder'],
                max_epochs=2,
                device='cpu',
                store_path=temp_dir,
                val_frequency=1,
                log_frequency=1,
                metrics_=MockMetrics()
            )

            train_losses, best_val_loss = trainer()

            assert len(train_losses) <= 2
            assert isinstance(best_val_loss, float)
536
+
537
+
538
class TestSampleLDM:
    """Tests covering the SampleLDM sampler."""

    @pytest.fixture
    def sampling_setup(self):
        """Build the compressor, mocks and reverse process used for sampling."""
        compressor = AutoencoderLDM(
            in_channels=3,
            down_channels=[8, 16],
            up_channels=[16, 8],
            out_channels=3,
            dropout_rate=0.1,
            latent_channels=4,
            num_heads=1,
            num_groups=8,
            num_layers_per_block=2,
            total_down_sampling_factor=2,
            num_embeddings=32,
        )

        noise_predictor = MockNoisePredictor(in_channels=4)
        text_encoder = MockTextEncoder(output_dim=32)

        scheduler = VarianceSchedulerSDE(num_steps=50, trainable_beta=False)
        reverse_sde = ReverseSDE(scheduler, "ode")

        return dict(
            compressor=compressor,
            noise_predictor=noise_predictor,
            text_encoder=text_encoder,
            reverse_sde=reverse_sde,
        )

    @staticmethod
    def _make_sampler(setup):
        """Construct the SampleLDM configuration shared by every test here."""
        return SampleLDM(
            diffusion_model="sde",
            reverse_diffusion=setup['reverse_sde'],
            noise_predictor=setup['noise_predictor'],
            compressor_model=setup['compressor'],
            image_shape=(32, 32),
            conditional_model=setup['text_encoder'],
            batch_size=2,
            device='cpu',
        )

    def test_sample_ldm_initialization(self, sampling_setup):
        """The sampler stores its diffusion model, batch size and image shape."""
        sampler = self._make_sampler(sampling_setup)

        assert sampler.diffusion_model == "sde"
        assert sampler.batch_size == 2
        assert sampler.image_shape == (32, 32)

    def test_sample_ldm_tokenization(self, sampling_setup):
        """Tokenizing returns one row per prompt."""
        sampler = self._make_sampler(sampling_setup)

        # A single prompt string
        input_ids, attention_mask = sampler.tokenize("test prompt")
        assert input_ids.shape[0] == 1

        # A list of prompts
        input_ids, attention_mask = sampler.tokenize(["prompt1", "prompt2"])
        assert input_ids.shape[0] == 2

    def test_sample_ldm_generation(self, sampling_setup):
        """Unconditional and conditional sampling return image batches."""
        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(sampling_setup)

            # Unconditional generation
            images = sampler(
                conditions=None,
                normalize_output=True,
                save_images=False,
            )

            assert images.shape == (2, 3, 32, 32)
            assert torch.all(images >= 0) and torch.all(images <= 1)

            # Conditional generation, writing the images to disk
            images_cond = sampler(
                conditions=["test1", "test2"],
                normalize_output=True,
                save_images=True,
                save_path=temp_dir,
            )

            assert images_cond.shape == (2, 3, 32, 32)

            # One file per generated image is expected in the directory
            saved_files = os.listdir(temp_dir)
            assert len(saved_files) == 2
645
+
646
+
647
class TestIntegration:
    """Integration tests for the full LDM pipeline."""

    def test_end_to_end_pipeline(self):
        """Run autoencoder training, LDM training and sampling in sequence."""
        # Minimal dataset of images paired with text labels. TensorDataset
        # only accepts tensors (it calls ``.size(0)`` on every argument), so
        # the string labels are paired with the images in a plain list, which
        # DataLoader accepts as a map-style dataset.
        data = torch.randn(8, 3, 32, 32)
        labels = ['class_' + str(i % 2) for i in range(8)]
        dataset = list(zip(data, labels))
        train_loader = DataLoader(dataset, batch_size=4)

        with tempfile.TemporaryDirectory() as temp_dir:
            # 1. Train the autoencoder
            compressor = AutoencoderLDM(
                in_channels=3, down_channels=[8, 16], up_channels=[16, 8],
                out_channels=3, dropout_rate=0.1, latent_channels=4,
                num_heads=1, num_groups=8, num_layers_per_block=2,
                total_down_sampling_factor=2, num_embeddings=32
            )

            ae_optimizer = torch.optim.Adam(compressor.parameters(), lr=1e-3)
            ae_trainer = TrainAE(
                model=compressor,
                optimizer=ae_optimizer,
                data_loader=train_loader,
                max_epochs=1,
                device='cpu',
                store_path=temp_dir
            )

            ae_losses, ae_best_loss = ae_trainer()
            assert len(ae_losses) == 1

            # 2. Train the LDM
            noise_predictor = MockNoisePredictor(in_channels=4)
            text_encoder = MockTextEncoder(output_dim=32)

            scheduler = VarianceSchedulerSDE(num_steps=20, trainable_beta=False)
            forward_sde = ForwardSDE(scheduler, "ode")
            reverse_sde = ReverseSDE(scheduler, "ode")

            ldm_optimizer = torch.optim.Adam([
                *noise_predictor.parameters(),
                *text_encoder.parameters()
            ], lr=1e-3)

            ldm_trainer = TrainLDM(
                diffusion_model="sde",
                forward_diffusion=forward_sde,
                reverse_diffusion=reverse_sde,
                noise_predictor=noise_predictor,
                compressor_model=compressor,
                optimizer=ldm_optimizer,
                objective=nn.MSELoss(),
                data_loader=train_loader,
                conditional_model=text_encoder,
                max_epochs=1,
                device='cpu',
                store_path=temp_dir
            )

            ldm_losses, ldm_best_loss = ldm_trainer()
            assert len(ldm_losses) == 1

            # 3. Sample from the trained model
            sampler = SampleLDM(
                diffusion_model="sde",
                reverse_diffusion=reverse_sde,
                noise_predictor=noise_predictor,
                compressor_model=compressor,
                image_shape=(32, 32),
                conditional_model=text_encoder,
                batch_size=2,
                device='cpu'
            )

            generated_images = sampler(
                conditions=["class_0", "class_1"],
                save_images=False
            )

            assert generated_images.shape == (2, 3, 32, 32)
            assert not torch.isnan(generated_images).any()
730
+
731
+
732
if __name__ == "__main__":
    # pytest is imported unconditionally at the top of this module, so an
    # ImportError fallback here could never run: a missing pytest would
    # already have aborted the import. Run this file's tests and pass
    # pytest's exit code to the shell so callers can detect failures.
    raise SystemExit(pytest.main([__file__, "-v"]))