TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,626 @@
1
+ import pytest
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import os
6
+ import tempfile
7
+ from unittest.mock import Mock, patch
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+ from torchdiff.utils import NoisePredictor, TextEncoder
10
+ from torchdiff.sde import (
11
+ VarianceSchedulerSDE,
12
+ ForwardSDE,
13
+ ReverseSDE,
14
+ TrainSDE,
15
+ SampleSDE
16
+ )
17
+
18
+
19
class TestVarianceSchedulerSDE:
    """Test cases for VarianceSchedulerSDE class."""

    @staticmethod
    def _close_to(values, target):
        """Return an elementwise ``torch.isclose`` mask of *values* vs scalar *target*.

        Schedules are computed with floating-point arithmetic, so comparing
        them to the configured Python floats with ``==`` (or strict ``<=`` /
        ``>=`` at the endpoints) can fail on harmless rounding differences.
        """
        return torch.isclose(values, torch.full_like(values, float(target)))

    def test_init(self):
        """Test initialization with default and custom parameters."""
        # Default initialization
        scheduler = VarianceSchedulerSDE()
        assert scheduler.num_steps == 1000
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02

        # Custom initialization: configured values are read back unchanged.
        scheduler = VarianceSchedulerSDE(
            num_steps=500,
            beta_start=1e-5,
            beta_end=0.01,
            beta_method="quadratic",
        )
        assert scheduler.num_steps == 500
        assert scheduler.beta_start == 1e-5
        assert scheduler.beta_end == 0.01

    def test_invalid_parameters(self):
        """Invalid configurations must raise ValueError."""
        with pytest.raises(ValueError):
            VarianceSchedulerSDE(num_steps=-10)

        with pytest.raises(ValueError):
            VarianceSchedulerSDE(beta_start=0.1, beta_end=0.01)  # start > end

        with pytest.raises(ValueError):
            VarianceSchedulerSDE(sigma_start=1.0, sigma_end=0.5)  # start > end

    def test_beta_schedule_methods(self):
        """Every beta schedule has num_steps entries within [beta_start, beta_end]."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            scheduler = VarianceSchedulerSDE(num_steps=100, beta_method=method)
            betas = scheduler.betas

            assert betas.shape[0] == 100, method
            # Accept values that sit on an endpoint up to floating-point rounding.
            above_start = (betas >= scheduler.beta_start) | self._close_to(
                betas, scheduler.beta_start
            )
            below_end = (betas <= scheduler.beta_end) | self._close_to(
                betas, scheduler.beta_end
            )
            assert torch.all(above_start), f"{method}: beta below beta_start"
            assert torch.all(below_end), f"{method}: beta above beta_end"

    def test_cumulative_betas(self):
        """Test cumulative beta computation."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        cum_betas = scheduler._cum_betas

        assert cum_betas.shape[0] == 100
        assert cum_betas[0] > 0  # Should be positive
        assert cum_betas[-1] > cum_betas[0]  # Should be increasing

    def test_sigmas(self):
        """The sigma schedule has num_steps entries spanning [sigma_start, sigma_end]."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        sigmas = scheduler.sigmas

        assert sigmas.shape[0] == 100
        # Endpoints come out of floating-point arithmetic; compare with a
        # tolerance instead of exact equality.
        assert self._close_to(sigmas[0], scheduler.sigma_start)
        assert self._close_to(sigmas[-1], scheduler.sigma_end)

    def test_get_variance(self):
        """Test variance computation for different methods."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        time_steps = torch.tensor([0, 50, 99])

        for method in ["ve", "vp", "sub-vp"]:
            variance = scheduler.get_variance(time_steps, method)
            assert variance.shape[0] == 3
            assert torch.all(variance >= 0)  # Variance should be non-negative
91
+
92
+
93
class TestForwardSDE:
    """Tests for ``ForwardSDE``: construction and one noising pass per SDE variant."""

    _SDE_METHODS = ("ve", "vp", "sub-vp", "ode")

    def setup_method(self):
        """Create a 100-step scheduler and the batch/image geometry used below."""
        self.scheduler = VarianceSchedulerSDE(num_steps=100)
        self.batch_size, self.channels = 4, 3
        self.height = self.width = 32

    def _run_forward(self, sde_method):
        """Noise a random clean batch with *sde_method*; return ``(x0, xt)``."""
        diffusion = ForwardSDE(self.scheduler, sde_method)
        clean = torch.randn(self.batch_size, self.channels, self.height, self.width)
        eps = torch.randn_like(clean)
        steps = torch.randint(0, 100, (self.batch_size,))
        return clean, diffusion(clean, eps, steps)

    def test_init(self):
        """Each supported method is stored as given; an unknown one raises ValueError."""
        for name in self._SDE_METHODS:
            assert ForwardSDE(self.scheduler, name).sde_method == name

        with pytest.raises(ValueError):
            ForwardSDE(self.scheduler, "invalid_method")

    def test_forward_ve(self):
        """VE forward process keeps the input shape."""
        x0, xt = self._run_forward("ve")
        assert xt.shape == x0.shape

    def test_forward_vp(self):
        """VP forward process keeps the input shape."""
        x0, xt = self._run_forward("vp")
        assert xt.shape == x0.shape

    def test_forward_sub_vp(self):
        """Sub-VP forward process keeps the input shape."""
        x0, xt = self._run_forward("sub-vp")
        assert xt.shape == x0.shape

    def test_forward_ode(self):
        """ODE forward process keeps the input shape."""
        x0, xt = self._run_forward("ode")
        assert xt.shape == x0.shape
152
+
153
+
154
class TestReverseSDE:
    """Tests for ``ReverseSDE``: construction and one reverse step per SDE variant."""

    _SDE_METHODS = ("ve", "vp", "sub-vp", "ode")

    def setup_method(self):
        """Create a 100-step scheduler and the batch/image geometry used below."""
        self.scheduler = VarianceSchedulerSDE(num_steps=100)
        self.batch_size, self.channels = 4, 3
        self.height = self.width = 32

    def _step_back(self, sde_method, first_step=0, with_noise=True):
        """Run one reverse step on random inputs and return ``(xt, xt_prev)``.

        *first_step* is the inclusive lower bound for the sampled time steps
        (upper bound 100, exclusive). When *with_noise* is False the step is
        called with ``noise=None``.
        """
        diffusion = ReverseSDE(self.scheduler, sde_method)
        state = torch.randn(self.batch_size, self.channels, self.height, self.width)
        fresh_noise = torch.randn_like(state) if with_noise else None
        eps_hat = torch.randn_like(state)
        steps = torch.randint(first_step, 100, (self.batch_size,))
        return state, diffusion(state, fresh_noise, eps_hat, steps)

    def test_init(self):
        """Each supported method is stored as given; an unknown one raises ValueError."""
        for name in self._SDE_METHODS:
            assert ReverseSDE(self.scheduler, name).sde_method == name

        with pytest.raises(ValueError):
            ReverseSDE(self.scheduler, "invalid_method")

    def test_reverse_ve(self):
        """VE reverse step keeps the input shape."""
        # VE samples t from [1, 100) instead of [0, 100).
        # NOTE(review): presumably the VE update reads the schedule at t - 1; confirm.
        xt, xt_prev = self._step_back("ve", first_step=1)
        assert xt_prev.shape == xt.shape

    def test_reverse_vp(self):
        """VP reverse step keeps the input shape."""
        xt, xt_prev = self._step_back("vp")
        assert xt_prev.shape == xt.shape

    def test_reverse_sub_vp(self):
        """Sub-VP reverse step keeps the input shape."""
        xt, xt_prev = self._step_back("sub-vp")
        assert xt_prev.shape == xt.shape

    def test_reverse_ode(self):
        """ODE reverse step keeps the input shape; no stochastic noise is passed."""
        xt, xt_prev = self._step_back("ode", with_noise=False)
        assert xt_prev.shape == xt.shape
217
+
218
+
219
class TestTrainSDE:
    """Test cases for TrainSDE class."""

    def setup_method(self):
        """Build tiny stand-in models, a 20-sample dataset and a VP SDE pipeline."""

        class SimpleNoisePredictor(nn.Module):
            """Shape-preserving 3->3 conv; ignores time step, conditioning and mask."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 3, 3, padding=1)

            def forward(self, x, t, y=None, mask=None):
                return self.conv(x)

        class SimpleConditionalModel(nn.Module):
            """Casts input ids to float and projects a trailing size-77 dim to 64."""

            def __init__(self):
                super().__init__()
                self.embed = nn.Linear(77, 64)

            def forward(self, input_ids, attention_mask=None):
                return self.embed(input_ids.float())

        # Test data: 20 random images with integer labels in [0, 10).
        self.batch_size = 4
        self.channels = 3
        self.height = 32
        self.width = 32

        x_data = torch.randn(20, self.channels, self.height, self.width)
        y_data = torch.randint(0, 10, (20,))
        dataset = TensorDataset(x_data, y_data)
        self.data_loader = DataLoader(dataset, batch_size=self.batch_size)

        # Components
        self.scheduler = VarianceSchedulerSDE(num_steps=10)
        self.forward_sde = ForwardSDE(self.scheduler, "vp")
        self.reverse_sde = ReverseSDE(self.scheduler, "vp")
        self.noise_predictor = SimpleNoisePredictor()
        self.conditional_model = SimpleConditionalModel()
        self.optimizer = torch.optim.Adam(
            list(self.noise_predictor.parameters())
            + list(self.conditional_model.parameters()),
            lr=1e-4,
        )
        self.objective = nn.MSELoss()

    def test_init(self):
        """A trainer can be constructed with a conditional model."""
        trainer = TrainSDE(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_sde,
            reverse_diffusion=self.reverse_sde,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            conditional_model=self.conditional_model,
            max_epochs=2,
        )

        assert trainer is not None

    def test_ddp_setup(self):
        """With ``use_ddp=True`` the constructor calls ``_setup_ddp`` exactly once."""
        # Patch the attribute on the ``TrainSDE`` class object this module
        # imported, so the patch is guaranteed to hit the class under test.
        # A dotted string such as 'sde.TrainSDE._setup_ddp' would resolve the
        # separate top-level ``sde`` package in the wheel (whose __init__.py
        # is empty), not ``torchdiff.sde``.
        with patch.object(TrainSDE, "_setup_ddp") as mock_setup_ddp:
            TrainSDE(
                noise_predictor=self.noise_predictor,
                forward_diffusion=self.forward_sde,
                reverse_diffusion=self.reverse_sde,
                data_loader=self.data_loader,
                optimizer=self.optimizer,
                objective=self.objective,
                use_ddp=True,
            )

        mock_setup_ddp.assert_called_once()

    def test_single_gpu_setup(self):
        """Without DDP the trainer reports a single rank-0 master process."""
        trainer = TrainSDE(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_sde,
            reverse_diffusion=self.reverse_sde,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            use_ddp=False,
        )

        assert trainer.ddp_rank == 0
        assert trainer.ddp_local_rank == 0
        assert trainer.ddp_world_size == 1
        assert trainer.master_process

    def test_warmup_scheduler(self):
        """``warmup_scheduler`` returns a scheduler object."""
        trainer = TrainSDE(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_sde,
            reverse_diffusion=self.reverse_sde,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            objective=self.objective,
        )

        scheduler = trainer.warmup_scheduler(self.optimizer, 10)
        assert scheduler is not None

    def test_process_conditional_input(self):
        """Conditional input is accepted both as a tensor and as a list of strings."""
        trainer = TrainSDE(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_sde,
            reverse_diffusion=self.reverse_sde,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            conditional_model=self.conditional_model,
        )

        # Tensor input
        y_tensor = torch.tensor([1, 2, 3, 4])
        y_encoded = trainer._process_conditional_input(y_tensor)
        assert y_encoded is not None

        # List input
        y_list = ["test1", "test2", "test3", "test4"]
        y_encoded = trainer._process_conditional_input(y_list)
        assert y_encoded is not None

    def test_save_checkpoint(self):
        """``_save_checkpoint`` writes a file named after the epoch into store_path."""
        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = TrainSDE(
                noise_predictor=self.noise_predictor,
                forward_diffusion=self.forward_sde,
                reverse_diffusion=self.reverse_sde,
                data_loader=self.data_loader,
                optimizer=self.optimizer,
                objective=self.objective,
                store_path=temp_dir,
            )

            trainer._save_checkpoint(1, 0.5)

            # Inspect while the temporary directory still exists.
            files = os.listdir(temp_dir)
            assert any(f.startswith("sde_epoch_1") for f in files)

    def test_validate(self):
        """``validate`` returns six floats when all metrics are enabled."""
        # Mock metrics object.
        # NOTE(review): the 5-tuple presumably maps to (fid, mse, psnr, ssim,
        # lpips) in the order ``validate`` returns them; confirm.
        mock_metrics = Mock()
        mock_metrics.forward.return_value = (1.0, 0.1, 25.0, 0.8, 0.2)
        mock_metrics.fid = True
        mock_metrics.metrics = True
        mock_metrics.lpips = True

        trainer = TrainSDE(
            noise_predictor=self.noise_predictor,
            forward_diffusion=self.forward_sde,
            reverse_diffusion=self.reverse_sde,
            data_loader=self.data_loader,
            optimizer=self.optimizer,
            objective=self.objective,
            val_loader=self.data_loader,
            metrics_=mock_metrics,
        )

        val_loss, fid, mse, psnr, ssim, lpips = trainer.validate()

        assert isinstance(val_loss, float)
        assert isinstance(fid, float)
        assert isinstance(mse, float)
        assert isinstance(psnr, float)
        assert isinstance(ssim, float)
        assert isinstance(lpips, float)
397
+
398
+
399
class TestSampleSDE:
    """Test cases for SampleSDE class."""

    def setup_method(self):
        """Build tiny stand-in models and a 10-step VP reverse process."""

        class SimpleNoisePredictor(nn.Module):
            """Shape-preserving 3->3 conv; ignores time step, conditioning and mask."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 3, 3, padding=1)

            # Same signature as the TrainSDE stub: accepts an optional mask so
            # the stub works whether or not the caller passes one.
            def forward(self, x, t, y=None, mask=None):
                return self.conv(x)

        class SimpleConditionalModel(nn.Module):
            """Casts input ids to float and projects a trailing size-77 dim to 64."""

            def __init__(self):
                super().__init__()
                self.embed = nn.Linear(77, 64)

            def forward(self, input_ids, attention_mask=None):
                return self.embed(input_ids.float())

        # Components
        self.scheduler = VarianceSchedulerSDE(num_steps=10)
        self.reverse_sde = ReverseSDE(self.scheduler, "vp")
        self.noise_predictor = SimpleNoisePredictor()
        self.conditional_model = SimpleConditionalModel()

    def test_init(self):
        """A sampler can be constructed without a conditional model."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
        )

        assert sampler is not None

    def test_tokenize(self):
        """``tokenize`` returns one row per prompt for a string or a list of strings."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            conditional_model=self.conditional_model,
        )

        # Single prompt
        input_ids, attention_mask = sampler.tokenize("a test prompt")
        assert input_ids.shape[0] == 1
        assert attention_mask.shape[0] == 1

        # Multiple prompts
        input_ids, attention_mask = sampler.tokenize(["prompt1", "prompt2"])
        assert input_ids.shape[0] == 2
        assert attention_mask.shape[0] == 2

    def test_forward_unconditional(self):
        """Unconditional sampling yields a batch of images and saves one file each."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            batch_size=2,
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            images = sampler.forward(
                conditions=None,
                save_images=True,
                save_path=temp_dir,
            )

            assert images.shape == (2, 3, 32, 32)
            # NOTE(review): assumes the default output range is [0, 1]; confirm.
            assert torch.all(images >= 0) and torch.all(images <= 1)

            # List saved files while the temporary directory still exists; it
            # is deleted as soon as the ``with`` block exits.
            files = os.listdir(temp_dir)
            assert len(files) == 2

    def test_forward_conditional(self):
        """Conditional sampling with text prompts yields one image per prompt."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            conditional_model=self.conditional_model,
            batch_size=2,
        )

        images = sampler.forward(
            conditions=["a cat", "a dog"],
            save_images=False,
        )

        assert images.shape == (2, 3, 32, 32)

    def test_to_device(self):
        """``to`` updates ``sampler.device`` and moves the wrapped modules."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
        )

        # CPU is always available, so the test runs on any machine.
        target_device = torch.device("cpu")
        sampler = sampler.to(target_device)

        assert sampler.device == target_device
        # Check every parameter. ``all`` over an empty iterator is True, so a
        # reverse process without learnable parameters does not make the test
        # error out (``next()`` on it would raise StopIteration).
        assert all(
            p.device == target_device for p in sampler.noise_predictor.parameters()
        )
        assert all(p.device == target_device for p in sampler.reverse.parameters())
512
+
513
+
514
def test_integration():
    """End-to-end smoke test: train a tiny conditional VP-SDE model, then sample."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 20 random 3x32x32 images with integer labels in [0, 10); the same data
    # backs both the (shuffled) training and the validation loader.
    images_in = torch.randn(20, 3, 32, 32)
    labels = torch.randint(0, 10, (20,))
    dataset = TensorDataset(images_in, labels)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=4, shuffle=False)

    # Deliberately small networks so the test stays fast.
    predictor_kwargs = dict(
        in_channels=3,
        down_channels=[8, 16],
        mid_channels=[16, 16],
        up_channels=[16, 8],
        down_sampling=[True, False],  # a single downsampling stage
        time_embed_dim=32,
        y_embed_dim=32,
        num_down_blocks=1,
        num_mid_blocks=1,
        num_up_blocks=1,
        down_sampling_factor=2,
    )
    noise_predictor = NoisePredictor(**predictor_kwargs).to(device)

    encoder_kwargs = dict(
        use_pretrained_model=False,  # skip pretrained weights for speed
        model_name="bert-base-uncased",
        vocabulary_size=100,
        num_layers=1,
        input_dimension=32,
        output_dimension=32,
        num_heads=2,
        context_length=10,
    )
    text_encoder = TextEncoder(**encoder_kwargs).to(device)

    # Optimize every trainable parameter: predictor first, then encoder.
    trainable = [
        param
        for module in (noise_predictor, text_encoder)
        for param in module.parameters()
        if param.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)
    objective = nn.MSELoss()

    # Short 10-step schedule for testing.
    schedule = VarianceSchedulerSDE(
        num_steps=10,
        beta_start=1e-4,
        beta_end=0.02,
        trainable_beta=False,
        sigma_start=1e-3,
        sigma_end=10.0,
        start=0.0,
        end=1.0,
        beta_method="linear",
    )
    forward_sde = ForwardSDE(variance_scheduler=schedule, sde_method="vp")
    reverse_sde = ReverseSDE(variance_scheduler=schedule, sde_method="vp")

    with tempfile.TemporaryDirectory() as temp_dir:
        trainer = TrainSDE(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_sde,
            reverse_diffusion=reverse_sde,
            data_loader=train_loader,
            optimizer=optimizer,
            objective=objective,
            val_loader=val_loader,
            max_epochs=2,
            device=device,
            conditional_model=text_encoder,
            metrics_=None,  # metrics disabled for speed
            store_path=temp_dir,
            val_frequency=1,
            use_ddp=False,
            grad_accumulation_steps=1,
            log_frequency=1,
            use_compilation=False,
        )

        train_losses, best_val_loss = trainer()
        # Only checks that a sized container came back (it may be empty).
        assert len(train_losses) >= 0
        assert isinstance(best_val_loss, float)

        # NOTE(review): tokenizer="bert-base-uncased" presumably loads a
        # tokenizer by name, which may need network access or a local cache.
        sampler = SampleSDE(
            reverse_diffusion=reverse_sde,
            noise_predictor=noise_predictor,
            image_shape=(32, 32),
            conditional_model=text_encoder,
            tokenizer="bert-base-uncased",
            max_token_length=10,
            batch_size=2,
            in_channels=3,
            device=device,
            image_output_range=(-1.0, 1.0),
        )

        images = sampler(["airplane", "automobile"], save_images=False)
        assert images.shape == (2, 3, 32, 32)
622
+
623
+
624
if __name__ == "__main__":
    # Run this module's tests directly. Propagate pytest's exit code so that
    # ``python test_sde.py`` exits non-zero when a test fails; discarding the
    # return value of ``pytest.main`` would always exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))