TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
@@ -0,0 +1,742 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import tempfile
|
|
5
|
+
import shutil
|
|
6
|
+
import os
|
|
7
|
+
from unittest.mock import Mock, patch
|
|
8
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
9
|
+
from torchvision import transforms
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
# LDM components under test, imported from torchdiff.ldm
|
|
13
|
+
from torchdiff.ldm import (
|
|
14
|
+
AutoencoderLDM, TrainAE, TrainLDM, SampleLDM,
|
|
15
|
+
VectorQuantizer, DownBlock, UpBlock, Conv3,
|
|
16
|
+
DownSampling, UpSampling, Attention
|
|
17
|
+
)
|
|
18
|
+
from torchdiff.sde import ForwardSDE, ReverseSDE, VarianceSchedulerSDE
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Mock utility classes standing in for components that would normally come from torchdiff.utils
|
|
23
|
+
class MockTextEncoder(nn.Module):
    """Stand-in text encoder that emits random embeddings.

    The module registers no parameters; its output depends only on the batch
    size of ``input_ids`` and on ``output_dim``.
    """

    def __init__(self, output_dim=32):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, input_ids, attention_mask):
        """Return a random ``(batch, output_dim)`` tensor.

        ``attention_mask`` is accepted for interface compatibility and ignored.
        """
        num_rows = input_ids.shape[0]
        embedding_shape = (num_rows, self.output_dim)
        return torch.randn(*embedding_shape)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MockNoisePredictor(nn.Module):
    """Minimal noise predictor: a single shape-preserving 3x3 convolution."""

    def __init__(self, in_channels=2):
        super().__init__()
        self.in_channels = in_channels
        # Same channel count in and out, padding=1 keeps the spatial size.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x, t, y=None, context=None):
        """Convolve ``x``; ``t``, ``y`` and ``context`` are accepted but unused."""
        prediction = self.conv(x)
        return prediction
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MockMetrics:
    """Metrics stub that reports fixed values for every enabled metric."""

    def __init__(self, device="cpu", fid=True, metrics=True, lpips_=True):
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_

    def forward(self, x_real, x_fake):
        """Return ``(fid, mse, psnr, ssim, lpips)``; both inputs are ignored.

        A disabled FID is reported as ``float('inf')``; every other disabled
        entry is reported as ``None``.
        """
        if self.fid:
            fid_score = torch.tensor(10.0)
        else:
            fid_score = float('inf')

        if self.metrics:
            mse = torch.tensor(0.1)
            psnr = torch.tensor(25.0)
            ssim = torch.tensor(0.8)
        else:
            mse = psnr = ssim = None

        lpips_score = torch.tensor(0.2) if self.lpips else None
        return fid_score, mse, psnr, ssim, lpips_score
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class TestAutoencoderLDM:
    """Unit tests covering the AutoencoderLDM model."""

    @pytest.fixture
    def autoencoder_config(self):
        """Keyword arguments for a small autoencoder with ``use_vq=False``."""
        return dict(
            in_channels=3,
            down_channels=[8, 16],
            up_channels=[16, 8],
            out_channels=3,
            dropout_rate=0.1,
            num_heads=1,
            num_groups=8,
            num_layers_per_block=2,
            total_down_sampling_factor=2,
            latent_channels=4,
            num_embeddings=32,
            use_vq=False,
            beta=1.0,
        )

    @pytest.fixture
    def sample_data(self):
        """Random batch of two 3x32x32 inputs."""
        return torch.randn(2, 3, 32, 32)

    def test_autoencoder_initialization(self, autoencoder_config):
        """Construct the model in KL mode and in VQ mode and check its flags."""
        # KL-divergence mode
        kl_model = AutoencoderLDM(**autoencoder_config)
        assert not kl_model.use_vq
        assert kl_model.beta == 1.0

        # VQ mode: same settings with only use_vq flipped
        vq_config = {**autoencoder_config, 'use_vq': True}
        vq_model = AutoencoderLDM(**vq_config)
        assert vq_model.use_vq
        assert hasattr(vq_model, 'vq_layer')

    def test_autoencoder_forward_pass(self, autoencoder_config, sample_data):
        """Run one forward pass and check output types and shapes."""
        model = AutoencoderLDM(**autoencoder_config)
        model.eval()

        with torch.no_grad():
            outputs = model(sample_data)
        x_hat, total_loss, reg_loss, z = outputs

        # Reconstruction has the input's shape
        assert x_hat.shape == sample_data.shape
        assert isinstance(total_loss, float)
        assert isinstance(reg_loss, (float, torch.Tensor))

        # Latent shape expected for this configuration
        assert z.shape == (2, 4, 16, 16)

    def test_encode_decode_consistency(self, autoencoder_config, sample_data):
        """Encode then decode and check that shapes round-trip."""
        model = AutoencoderLDM(**autoencoder_config)
        model.eval()

        with torch.no_grad():
            latent, _reg_loss = model.encode(sample_data)
            reconstruction = model.decode(latent)

        assert reconstruction.shape == sample_data.shape
        # Batch dimension survives encoding
        assert latent.shape[0] == sample_data.shape[0]

    def test_vq_functionality(self, autoencoder_config, sample_data):
        """Encode in VQ mode and inspect the returned VQ loss."""
        vq_config = {**autoencoder_config, 'use_vq': True}
        model = AutoencoderLDM(**vq_config)
        model.eval()

        with torch.no_grad():
            _latent, vq_loss = model.encode(sample_data)

        assert isinstance(vq_loss, torch.Tensor)
        assert not vq_loss.requires_grad

    def test_reparameterization_trick(self, autoencoder_config):
        """Check ``reparameterize`` output shape and its low-variance limit."""
        model = AutoencoderLDM(**autoencoder_config)

        stats_shape = (2, 8, 16, 16)
        mu = torch.randn(*stats_shape)
        logvar = torch.randn(*stats_shape)

        sampled = model.reparameterize(mu, logvar)
        assert sampled.shape == mu.shape

        # With a very negative log-variance the sample should sit close to mu
        tiny_logvar = torch.full_like(logvar, -20)
        near_mean = model.reparameterize(mu, tiny_logvar)
        torch.testing.assert_close(near_mean, mu, atol=1e-3, rtol=1e-3)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class TestVectorQuantizer:
    """Unit tests covering the VectorQuantizer layer."""

    def test_vq_initialization(self):
        """Constructor arguments are stored and size the embedding table."""
        codebook_size, code_dim = 64, 32
        quantizer = VectorQuantizer(num_embeddings=codebook_size, embedding_dim=code_dim)

        assert quantizer.num_embeddings == 64
        assert quantizer.embedding_dim == 32
        assert quantizer.embedding.weight.shape == (64, 32)

    def test_vq_forward_pass(self):
        """Forward pass keeps the input shape and returns a tensor loss."""
        quantizer = VectorQuantizer(num_embeddings=64, embedding_dim=16)
        features = torch.randn(2, 16, 8, 8)

        result = quantizer(features)
        quantized, vq_loss = result

        assert quantized.shape == features.shape
        assert isinstance(vq_loss, torch.Tensor)
        assert vq_loss.requires_grad
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class TestConvolutionalBlocks:
    """Shape checks for the convolutional building blocks."""

    def test_conv3_block(self):
        """Conv3 changes the channel count and keeps the spatial size."""
        block = Conv3(in_channels=16, out_channels=32, dropout_rate=0.1)
        inputs = torch.randn(2, 16, 32, 32)

        assert block(inputs).shape == (2, 32, 32, 32)

    def test_down_block(self):
        """DownBlock with factor 2 halves height and width."""
        block = DownBlock(
            in_channels=16,
            out_channels=32,
            num_layers=2,
            down_sampling_factor=2,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 16, 32, 32)

        result = block(inputs)
        assert result.shape == (2, 32, 16, 16)

    def test_up_block(self):
        """UpBlock with factor 2 doubles height and width."""
        block = UpBlock(
            in_channels=32,
            out_channels=16,
            num_layers=2,
            up_sampling_factor=2,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 32, 16, 16)

        result = block(inputs)
        assert result.shape == (2, 16, 32, 32)

    def test_attention_block(self):
        """Attention returns a tensor shaped like its input."""
        block = Attention(
            num_channels=32,
            num_heads=4,
            num_groups=8,
            dropout_rate=0.1,
        )
        inputs = torch.randn(2, 32, 16, 16)

        result = block(inputs)
        assert result.shape == inputs.shape
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class TestSamplingLayers:
    """Shape checks for the resampling layers."""

    def test_down_sampling(self):
        """DownSampling with factor 2 maps 16x32x32 to 32x16x16."""
        layer = DownSampling(in_channels=16, out_channels=32, down_sampling_factor=2)
        inputs = torch.randn(2, 16, 32, 32)

        result = layer(inputs)
        assert result.shape == (2, 32, 16, 16)

    def test_up_sampling(self):
        """UpSampling with factor 2 maps 32x16x16 to 16x32x32."""
        layer = UpSampling(in_channels=32, out_channels=16, up_sampling_factor=2)
        inputs = torch.randn(2, 32, 16, 16)

        result = layer(inputs)
        assert result.shape == (2, 16, 32, 32)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class TestVarianceSchedulerSDE:
    """Unit tests covering VarianceSchedulerSDE."""

    def test_scheduler_initialization(self):
        """Constructor arguments are exposed as attributes."""
        scheduler = VarianceSchedulerSDE(
            num_steps=100,
            beta_start=1e-4,
            beta_end=0.02,
            trainable_beta=False,
        )

        assert scheduler.num_steps == 100
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02
        assert not scheduler.trainable_beta

    def test_beta_schedules(self):
        """Every beta method yields 100 values inside [beta_start, beta_end]."""
        for beta_method in ("linear", "sigmoid", "quadratic", "constant"):
            scheduler = VarianceSchedulerSDE(
                num_steps=100,
                beta_start=1e-4,
                beta_end=0.02,
                beta_method=beta_method,
                trainable_beta=False,
            )

            schedule = scheduler.betas
            assert schedule.shape == (100,)
            assert torch.all(schedule >= 1e-4)
            assert torch.all(schedule <= 0.02)

    def test_trainable_beta(self):
        """With trainable betas, ``beta_raw`` is a Parameter and betas stay bounded."""
        scheduler = VarianceSchedulerSDE(num_steps=50, trainable_beta=True)

        # beta_raw must exist and be registered as a parameter
        assert hasattr(scheduler, 'beta_raw')
        assert isinstance(scheduler.beta_raw, nn.Parameter)

        # Derived betas stay within the scheduler's own bounds
        schedule = scheduler.betas
        assert torch.all(schedule >= scheduler.beta_start)
        assert torch.all(schedule <= scheduler.beta_end)

    def test_variance_computation(self):
        """``get_variance`` returns one non-negative value per time step."""
        scheduler = VarianceSchedulerSDE(num_steps=100, trainable_beta=False)
        time_steps = torch.tensor([10, 20, 30])

        for sde_method in ("ve", "vp", "sub-vp"):
            variance = scheduler.get_variance(time_steps, sde_method)
            assert variance.shape == time_steps.shape
            assert torch.all(variance >= 0)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class TestSDEProcesses:
    """Unit tests covering ForwardSDE and ReverseSDE."""

    @pytest.fixture
    def sde_setup(self):
        """Variance scheduler shared by the SDE process tests."""
        return VarianceSchedulerSDE(num_steps=100, beta_start=1e-4, beta_end=0.02)

    def test_forward_sde_methods(self, sde_setup):
        """ForwardSDE keeps the input shape and yields no NaNs for each method."""
        x0 = torch.randn(2, 3, 32, 32)
        noise = torch.randn_like(x0)
        time_steps = torch.tensor([10, 20])

        for sde_method in ("ve", "vp", "sub-vp", "ode"):
            forward_process = ForwardSDE(sde_setup, sde_method)
            noised = forward_process(x0, noise, time_steps)

            assert noised.shape == x0.shape
            assert not torch.isnan(noised).any()

    def test_reverse_sde_methods(self, sde_setup):
        """ReverseSDE keeps the input shape and yields no NaNs for each method."""
        xt = torch.randn(2, 3, 32, 32)
        noise = torch.randn_like(xt)
        predicted_noise = torch.randn_like(xt)
        time_steps = torch.tensor([10, 20])

        for sde_method in ("ve", "vp", "sub-vp", "ode"):
            reverse_process = ReverseSDE(sde_setup, sde_method)
            denoised = reverse_process(xt, noise, predicted_noise, time_steps)

            assert denoised.shape == xt.shape
            assert not torch.isnan(denoised).any()

    def test_ode_method_without_noise(self, sde_setup):
        """The "ode" reverse step is called with ``None`` in place of noise."""
        forward_process = ForwardSDE(sde_setup, "ode")
        reverse_process = ReverseSDE(sde_setup, "ode")

        x0 = torch.randn(2, 3, 32, 32)
        xt = torch.randn_like(x0)
        predicted_noise = torch.randn_like(x0)
        time_steps = torch.tensor([10, 20])

        # Forward step still receives a noise tensor
        forward_result = forward_process(x0, torch.randn_like(x0), time_steps)

        # Reverse step is given None for the noise argument
        reverse_result = reverse_process(xt, None, predicted_noise, time_steps)

        assert not torch.isnan(forward_result).any()
        assert not torch.isnan(reverse_result).any()
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class TestTrainAE:
    """Unit tests covering the TrainAE autoencoder trainer."""

    @pytest.fixture
    def training_setup(self):
        """Build loaders, model, optimizer and mock metrics for the trainer."""
        # Small random dataset with integer labels
        images = torch.randn(20, 3, 32, 32)
        targets = torch.randint(0, 10, (20,))
        dataset = TensorDataset(images, targets)
        train_loader = DataLoader(dataset, batch_size=4)
        val_loader = DataLoader(dataset, batch_size=4)

        # Model and optimizer
        model = AutoencoderLDM(
            in_channels=3,
            down_channels=[8, 16],
            up_channels=[16, 8],
            out_channels=3,
            dropout_rate=0.1,
            latent_channels=4,
            num_heads=1,
            num_groups=8,
            num_layers_per_block=2,
            total_down_sampling_factor=2,
            num_embeddings=32,
        )
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        return {
            'model': model,
            'optimizer': optimizer,
            'train_loader': train_loader,
            'val_loader': val_loader,
            'metrics': MockMetrics(),
        }

    @staticmethod
    def _trainer_kwargs(setup):
        """Keyword arguments common to every TrainAE built in this class."""
        return dict(
            model=setup['model'],
            optimizer=setup['optimizer'],
            data_loader=setup['train_loader'],
            val_loader=setup['val_loader'],
            max_epochs=2,
            metrics_=setup['metrics'],
            device='cpu',
        )

    def test_train_ae_initialization(self, training_setup):
        """Constructor stores the epoch budget and the device."""
        trainer = TrainAE(**self._trainer_kwargs(training_setup))

        assert trainer.max_epochs == 2
        assert trainer.device.type == 'cpu'

    def test_train_ae_forward(self, training_setup):
        """Run the training loop for two epochs, validating every epoch."""
        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = TrainAE(
                **self._trainer_kwargs(training_setup),
                store_path=temp_dir,
                val_frequency=1,
                log_frequency=1,
            )

            train_losses, best_val_loss = trainer()

            assert len(train_losses) <= 2
            assert isinstance(best_val_loss, float)
            assert best_val_loss < float('inf')
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
class TestTrainLDM:
    """Unit tests covering the TrainLDM latent-diffusion trainer."""

    @pytest.fixture
    def ldm_training_setup(self):
        """Build loaders, models, SDE processes and optimizer for the trainer."""
        # Dataset of images paired with text labels.  TensorDataset accepts
        # tensors only (it calls ``.size(0)`` on each argument), so the string
        # labels are paired with the images in a plain list instead; the
        # DataLoader's default collation then yields a stacked image tensor
        # and a list of strings per batch.
        data = torch.randn(16, 3, 32, 32)
        labels = ['class_' + str(i % 4) for i in range(16)]
        dataset = list(zip(data, labels))
        train_loader = DataLoader(dataset, batch_size=4)
        val_loader = DataLoader(dataset, batch_size=4)

        # Model components
        compressor = AutoencoderLDM(
            in_channels=3, down_channels=[8, 16], up_channels=[16, 8],
            out_channels=3, dropout_rate=0.1, latent_channels=4,
            num_heads=1, num_groups=8, num_layers_per_block=2,
            total_down_sampling_factor=2, num_embeddings=32
        )

        noise_predictor = MockNoisePredictor(in_channels=4)
        text_encoder = MockTextEncoder(output_dim=32)

        scheduler = VarianceSchedulerSDE(num_steps=50, trainable_beta=False)
        forward_sde = ForwardSDE(scheduler, "ode")
        reverse_sde = ReverseSDE(scheduler, "ode")

        optimizer = torch.optim.Adam([
            *noise_predictor.parameters(),
            *text_encoder.parameters()
        ], lr=1e-3)

        return {
            'compressor': compressor,
            'noise_predictor': noise_predictor,
            'text_encoder': text_encoder,
            'forward_sde': forward_sde,
            'reverse_sde': reverse_sde,
            'optimizer': optimizer,
            'train_loader': train_loader,
            'val_loader': val_loader
        }

    @staticmethod
    def _trainer_kwargs(setup):
        """Keyword arguments common to every TrainLDM built in this class."""
        return dict(
            diffusion_model="sde",
            forward_diffusion=setup['forward_sde'],
            reverse_diffusion=setup['reverse_sde'],
            noise_predictor=setup['noise_predictor'],
            compressor_model=setup['compressor'],
            optimizer=setup['optimizer'],
            objective=nn.MSELoss(),
            data_loader=setup['train_loader'],
            val_loader=setup['val_loader'],
            conditional_model=setup['text_encoder'],
            max_epochs=2,
            device='cpu',
        )

    def test_train_ldm_initialization(self, ldm_training_setup):
        """Constructor stores the diffusion model name and the epoch budget."""
        trainer = TrainLDM(**self._trainer_kwargs(ldm_training_setup))

        assert trainer.diffusion_model == "sde"
        assert trainer.max_epochs == 2

    def test_train_ldm_forward(self, ldm_training_setup):
        """Run the training loop for two epochs, validating every epoch."""
        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = TrainLDM(
                **self._trainer_kwargs(ldm_training_setup),
                store_path=temp_dir,
                val_frequency=1,
                log_frequency=1,
                metrics_=MockMetrics(),
            )

            train_losses, best_val_loss = trainer()

            assert len(train_losses) <= 2
            assert isinstance(best_val_loss, float)
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class TestSampleLDM:
    """Unit tests covering the SampleLDM sampler."""

    @pytest.fixture
    def sampling_setup(self):
        """Build the compressor, mocks and reverse process used by the sampler."""
        compressor = AutoencoderLDM(
            in_channels=3,
            down_channels=[8, 16],
            up_channels=[16, 8],
            out_channels=3,
            dropout_rate=0.1,
            latent_channels=4,
            num_heads=1,
            num_groups=8,
            num_layers_per_block=2,
            total_down_sampling_factor=2,
            num_embeddings=32,
        )

        noise_predictor = MockNoisePredictor(in_channels=4)
        text_encoder = MockTextEncoder(output_dim=32)

        scheduler = VarianceSchedulerSDE(num_steps=50, trainable_beta=False)
        reverse_sde = ReverseSDE(scheduler, "ode")

        return {
            'compressor': compressor,
            'noise_predictor': noise_predictor,
            'text_encoder': text_encoder,
            'reverse_sde': reverse_sde,
        }

    @staticmethod
    def _make_sampler(setup):
        """Construct the SampleLDM configuration shared by every test here."""
        return SampleLDM(
            diffusion_model="sde",
            reverse_diffusion=setup['reverse_sde'],
            noise_predictor=setup['noise_predictor'],
            compressor_model=setup['compressor'],
            image_shape=(32, 32),
            conditional_model=setup['text_encoder'],
            batch_size=2,
            device='cpu',
        )

    def test_sample_ldm_initialization(self, sampling_setup):
        """Constructor stores the diffusion model name, batch size and image shape."""
        sampler = self._make_sampler(sampling_setup)

        assert sampler.diffusion_model == "sde"
        assert sampler.batch_size == 2
        assert sampler.image_shape == (32, 32)

    def test_sample_ldm_tokenization(self, sampling_setup):
        """``tokenize`` returns one row of ids per prompt."""
        sampler = self._make_sampler(sampling_setup)

        # A single string prompt
        single_ids, _single_mask = sampler.tokenize("test prompt")
        assert single_ids.shape[0] == 1

        # A list of prompts
        pair_ids, _pair_mask = sampler.tokenize(["prompt1", "prompt2"])
        assert pair_ids.shape[0] == 2

    def test_sample_ldm_generation(self, sampling_setup):
        """Generate images without and with conditions; the latter saves to disk."""
        with tempfile.TemporaryDirectory() as temp_dir:
            sampler = self._make_sampler(sampling_setup)

            # Unconditional generation, normalized output, nothing saved
            images = sampler(
                conditions=None,
                normalize_output=True,
                save_images=False,
            )

            assert images.shape == (2, 3, 32, 32)
            assert torch.all(images >= 0) and torch.all(images <= 1)

            # Conditional generation, saving into the temporary directory
            images_cond = sampler(
                conditions=["test1", "test2"],
                normalize_output=True,
                save_images=True,
                save_path=temp_dir,
            )

            assert images_cond.shape == (2, 3, 32, 32)

            # One saved file is expected per generated image
            saved_files = os.listdir(temp_dir)
            assert len(saved_files) == 2
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
class TestIntegration:
    """Integration test for the full LDM pipeline."""

    def test_end_to_end_pipeline(self):
        """Train the autoencoder, train the LDM, then sample from the result."""
        # Minimal dataset of images paired with text labels.  TensorDataset
        # accepts tensors only (it calls ``.size(0)`` on each argument), so the
        # string labels are paired with the images in a plain list instead;
        # the DataLoader's default collation then yields a stacked image
        # tensor and a list of strings per batch.
        data = torch.randn(8, 3, 32, 32)
        labels = ['class_' + str(i % 2) for i in range(8)]
        dataset = list(zip(data, labels))
        train_loader = DataLoader(dataset, batch_size=4)

        with tempfile.TemporaryDirectory() as temp_dir:
            # 1. Train AutoEncoder
            compressor = AutoencoderLDM(
                in_channels=3, down_channels=[8, 16], up_channels=[16, 8],
                out_channels=3, dropout_rate=0.1, latent_channels=4,
                num_heads=1, num_groups=8, num_layers_per_block=2,
                total_down_sampling_factor=2, num_embeddings=32
            )

            ae_optimizer = torch.optim.Adam(compressor.parameters(), lr=1e-3)
            ae_trainer = TrainAE(
                model=compressor,
                optimizer=ae_optimizer,
                data_loader=train_loader,
                max_epochs=1,
                device='cpu',
                store_path=temp_dir
            )

            ae_losses, ae_best_loss = ae_trainer()
            assert len(ae_losses) == 1

            # 2. Train LDM on top of the trained compressor
            noise_predictor = MockNoisePredictor(in_channels=4)
            text_encoder = MockTextEncoder(output_dim=32)

            scheduler = VarianceSchedulerSDE(num_steps=20, trainable_beta=False)
            forward_sde = ForwardSDE(scheduler, "ode")
            reverse_sde = ReverseSDE(scheduler, "ode")

            ldm_optimizer = torch.optim.Adam([
                *noise_predictor.parameters(),
                *text_encoder.parameters()
            ], lr=1e-3)

            ldm_trainer = TrainLDM(
                diffusion_model="sde",
                forward_diffusion=forward_sde,
                reverse_diffusion=reverse_sde,
                noise_predictor=noise_predictor,
                compressor_model=compressor,
                optimizer=ldm_optimizer,
                objective=nn.MSELoss(),
                data_loader=train_loader,
                conditional_model=text_encoder,
                max_epochs=1,
                device='cpu',
                store_path=temp_dir
            )

            ldm_losses, ldm_best_loss = ldm_trainer()
            assert len(ldm_losses) == 1

            # 3. Sample from the trained model
            sampler = SampleLDM(
                diffusion_model="sde",
                reverse_diffusion=reverse_sde,
                noise_predictor=noise_predictor,
                compressor_model=compressor,
                image_shape=(32, 32),
                conditional_model=text_encoder,
                batch_size=2,
                device='cpu'
            )

            generated_images = sampler(
                conditions=["class_0", "class_1"],
                save_images=False
            )

            assert generated_images.shape == (2, 3, 32, 32)
            assert not torch.isnan(generated_images).any()
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
if __name__ == "__main__":
    # Allow running this test module directly as a script.  pytest is already
    # imported at module level, so reaching this point means it is available;
    # no runtime installation is attempted.
    import sys

    # Propagate pytest's exit status so shells and CI see test failures.
    sys.exit(pytest.main([__file__, "-v"]))
|