TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
@@ -0,0 +1,626 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import numpy as np
|
|
5
|
+
import os
|
|
6
|
+
import tempfile
|
|
7
|
+
from unittest.mock import Mock, patch
|
|
8
|
+
from torch.utils.data import DataLoader, TensorDataset
|
|
9
|
+
from torchdiff.utils import NoisePredictor, TextEncoder
|
|
10
|
+
from torchdiff.sde import (
|
|
11
|
+
VarianceSchedulerSDE,
|
|
12
|
+
ForwardSDE,
|
|
13
|
+
ReverseSDE,
|
|
14
|
+
TrainSDE,
|
|
15
|
+
SampleSDE
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestVarianceSchedulerSDE:
    """Test cases for VarianceSchedulerSDE class."""

    # Relative tolerance for comparing schedule tensors against the Python-float
    # configuration values they were derived from. Schedules built via
    # sqrt/square or log/exp can land a few ULPs away from the exact endpoint,
    # so exact `==` is brittle. NOTE(review): assumes the schedule tensors are
    # float32 or float64; confirm against VarianceSchedulerSDE.
    _RTOL = 1e-5

    def test_init(self):
        """Test initialization with default and custom parameters."""
        # Default initialization
        scheduler = VarianceSchedulerSDE()
        assert scheduler.num_steps == 1000
        assert scheduler.beta_start == 1e-4
        assert scheduler.beta_end == 0.02

        # Custom initialization
        scheduler = VarianceSchedulerSDE(
            num_steps=500,
            beta_start=1e-5,
            beta_end=0.01,
            beta_method="quadratic",
        )
        assert scheduler.num_steps == 500
        assert scheduler.beta_start == 1e-5
        assert scheduler.beta_end == 0.01

    def test_invalid_parameters(self):
        """Each invalid argument combination must raise ValueError."""
        with pytest.raises(ValueError):
            VarianceSchedulerSDE(num_steps=-10)

        with pytest.raises(ValueError):
            VarianceSchedulerSDE(beta_start=0.1, beta_end=0.01)  # start > end

        with pytest.raises(ValueError):
            VarianceSchedulerSDE(sigma_start=1.0, sigma_end=0.5)  # start > end

    def test_beta_schedule_methods(self):
        """Every beta schedule yields num_steps values within [beta_start, beta_end]."""
        methods = ["linear", "sigmoid", "quadratic", "constant", "inverse_time"]

        for method in methods:
            scheduler = VarianceSchedulerSDE(num_steps=100, beta_method=method)
            betas = scheduler.betas

            # Widen the bounds by a relative epsilon so floating-point rounding
            # at the endpoints (e.g. sqrt-then-square) does not cause failures.
            lower = float(scheduler.beta_start) * (1.0 - self._RTOL)
            upper = float(scheduler.beta_end) * (1.0 + self._RTOL)

            assert betas.shape[0] == 100, method
            assert torch.all(betas >= lower), method
            assert torch.all(betas <= upper), method

    def test_cumulative_betas(self):
        """Test cumulative beta computation (private `_cum_betas` attribute)."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        cum_betas = scheduler._cum_betas

        assert cum_betas.shape[0] == 100
        assert cum_betas[0] > 0  # Should be positive
        assert cum_betas[-1] > cum_betas[0]  # Should be increasing

    def test_sigmas(self):
        """Sigmas have num_steps entries spanning sigma_start..sigma_end."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        sigmas = scheduler.sigmas

        assert sigmas.shape[0] == 100
        # float() accepts both Python floats and 0-d tensors for the configured
        # endpoints; approx tolerates rounding in the schedule computation.
        assert sigmas[0].item() == pytest.approx(
            float(scheduler.sigma_start), rel=self._RTOL
        )
        assert sigmas[-1].item() == pytest.approx(
            float(scheduler.sigma_end), rel=self._RTOL
        )

    def test_get_variance(self):
        """Variance has one entry per time step and is non-negative for each method."""
        scheduler = VarianceSchedulerSDE(num_steps=100)
        time_steps = torch.tensor([0, 50, 99])

        for method in ["ve", "vp", "sub-vp"]:
            variance = scheduler.get_variance(time_steps, method)
            assert variance.shape[0] == 3, method
            # Variance should be non-negative
            assert torch.all(variance >= 0), method
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TestForwardSDE:
    """Shape checks for every ForwardSDE perturbation method."""

    def setup_method(self):
        """Create a 100-step scheduler and the image geometry shared by each test."""
        self.scheduler = VarianceSchedulerSDE(num_steps=100)
        self.batch_size, self.channels = 4, 3
        self.height = self.width = 32

    def _check_forward_preserves_shape(self, sde_method):
        """Perturb random images with ``ForwardSDE(sde_method)`` and check the shape.

        Args:
            sde_method: Method name forwarded to ``ForwardSDE``.
        """
        forward_sde = ForwardSDE(self.scheduler, sde_method)
        clean = torch.randn(self.batch_size, self.channels, self.height, self.width)
        eps = torch.randn_like(clean)
        steps = torch.randint(0, 100, (self.batch_size,))

        perturbed = forward_sde(clean, eps, steps)
        assert perturbed.shape == clean.shape

    def test_init(self):
        """Each supported method is stored verbatim; unknown names raise ValueError."""
        for sde_method in ("ve", "vp", "sub-vp", "ode"):
            assert ForwardSDE(self.scheduler, sde_method).sde_method == sde_method

        with pytest.raises(ValueError):
            ForwardSDE(self.scheduler, "invalid_method")

    def test_forward_ve(self):
        """The VE forward process keeps the input shape."""
        self._check_forward_preserves_shape("ve")

    def test_forward_vp(self):
        """The VP forward process keeps the input shape."""
        self._check_forward_preserves_shape("vp")

    def test_forward_sub_vp(self):
        """The sub-VP forward process keeps the input shape."""
        self._check_forward_preserves_shape("sub-vp")

    def test_forward_ode(self):
        """The ODE forward process keeps the input shape."""
        self._check_forward_preserves_shape("ode")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TestReverseSDE:
    """Shape checks for every ReverseSDE update rule."""

    def setup_method(self):
        """Create a 100-step scheduler and the image geometry shared by each test."""
        self.scheduler = VarianceSchedulerSDE(num_steps=100)
        self.batch_size, self.channels = 4, 3
        self.height = self.width = 32

    def _check_reverse_preserves_shape(self, sde_method, first_step=0, with_noise=True):
        """Take one reverse step on random inputs and check the output shape.

        Args:
            sde_method: Method name forwarded to ``ReverseSDE``.
            first_step: Inclusive lower bound of the sampled time steps.
            with_noise: When False, ``None`` is passed as the noise argument.
        """
        reverse_sde = ReverseSDE(self.scheduler, sde_method)
        current = torch.randn(self.batch_size, self.channels, self.height, self.width)
        step_noise = torch.randn_like(current) if with_noise else None
        eps_hat = torch.randn_like(current)
        steps = torch.randint(first_step, 100, (self.batch_size,))

        previous = reverse_sde(current, step_noise, eps_hat, steps)
        assert previous.shape == current.shape

    def test_init(self):
        """Each supported method is stored verbatim; unknown names raise ValueError."""
        for sde_method in ("ve", "vp", "sub-vp", "ode"):
            assert ReverseSDE(self.scheduler, sde_method).sde_method == sde_method

        with pytest.raises(ValueError):
            ReverseSDE(self.scheduler, "invalid_method")

    def test_reverse_ve(self):
        """A VE reverse step keeps the input shape (time steps drawn from [1, 100))."""
        self._check_reverse_preserves_shape("ve", first_step=1)

    def test_reverse_vp(self):
        """A VP reverse step keeps the input shape."""
        self._check_reverse_preserves_shape("vp")

    def test_reverse_sub_vp(self):
        """A sub-VP reverse step keeps the input shape."""
        self._check_reverse_preserves_shape("sub-vp")

    def test_reverse_ode(self):
        """An ODE reverse step keeps the input shape; no noise tensor is supplied."""
        self._check_reverse_preserves_shape("ode", with_noise=False)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class TestTrainSDE:
    """Test cases for TrainSDE class."""

    def setup_method(self):
        """Build tiny stand-in models, a 20-sample dataset and VP SDE components."""

        # Create simple models for testing
        class SimpleNoisePredictor(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 3, 3, padding=1)

            def forward(self, x, t, y=None, mask=None):
                # Ignores time step and conditioning; only the output shape matters.
                return self.conv(x)

        class SimpleConditionalModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = nn.Linear(77, 64)

            def forward(self, input_ids, attention_mask=None):
                return self.embed(input_ids.float())

        # Create test data
        self.batch_size = 4
        self.channels = 3
        self.height = 32
        self.width = 32

        x_data = torch.randn(20, self.channels, self.height, self.width)
        y_data = torch.randint(0, 10, (20,))
        dataset = TensorDataset(x_data, y_data)
        self.data_loader = DataLoader(dataset, batch_size=self.batch_size)

        # Create components
        self.scheduler = VarianceSchedulerSDE(num_steps=10)
        self.forward_sde = ForwardSDE(self.scheduler, "vp")
        self.reverse_sde = ReverseSDE(self.scheduler, "vp")
        self.noise_predictor = SimpleNoisePredictor()
        self.conditional_model = SimpleConditionalModel()
        self.optimizer = torch.optim.Adam(
            list(self.noise_predictor.parameters())
            + list(self.conditional_model.parameters()),
            lr=1e-4,
        )
        self.objective = nn.MSELoss()

    def _make_trainer(self, **overrides):
        """Construct a TrainSDE from the shared fixtures.

        Args:
            **overrides: Extra or replacement keyword arguments forwarded to
                ``TrainSDE`` (e.g. ``use_ddp=True``, ``store_path=...``).

        Returns:
            The constructed ``TrainSDE`` instance.
        """
        kwargs = {
            "noise_predictor": self.noise_predictor,
            "forward_diffusion": self.forward_sde,
            "reverse_diffusion": self.reverse_sde,
            "data_loader": self.data_loader,
            "optimizer": self.optimizer,
            "objective": self.objective,
        }
        kwargs.update(overrides)
        return TrainSDE(**kwargs)

    def test_init(self):
        """Test initialization."""
        trainer = self._make_trainer(
            conditional_model=self.conditional_model,
            max_epochs=2,
        )

        assert trainer is not None

    def test_ddp_setup(self):
        """use_ddp=True must delegate to TrainSDE._setup_ddp exactly once."""
        # Patch the imported class object rather than a dotted string path: the
        # top-level `sde` package in this wheel has an empty __init__.py and
        # exposes no TrainSDE, so 'sde.TrainSDE._setup_ddp' cannot be patched.
        # NOTE(review): with _setup_ddp mocked, any DDP attributes it would set
        # are missing; confirm the rest of TrainSDE.__init__ does not need them.
        with patch.object(TrainSDE, "_setup_ddp") as mock_setup_ddp:
            self._make_trainer(use_ddp=True)

        mock_setup_ddp.assert_called_once()

    def test_single_gpu_setup(self):
        """Without DDP the trainer behaves as a single rank-0 master process."""
        trainer = self._make_trainer(use_ddp=False)

        assert trainer.ddp_rank == 0
        assert trainer.ddp_local_rank == 0
        assert trainer.ddp_world_size == 1
        assert trainer.master_process

    def test_warmup_scheduler(self):
        """Test warmup scheduler creation."""
        trainer = self._make_trainer()

        scheduler = trainer.warmup_scheduler(self.optimizer, 10)
        assert scheduler is not None

    def test_process_conditional_input(self):
        """Tensor and list-of-string conditions both produce an encoding."""
        trainer = self._make_trainer(conditional_model=self.conditional_model)

        # Test with tensor input
        y_tensor = torch.tensor([1, 2, 3, 4])
        y_encoded = trainer._process_conditional_input(y_tensor)
        assert y_encoded is not None

        # Test with list input
        y_list = ["test1", "test2", "test3", "test4"]
        y_encoded = trainer._process_conditional_input(y_list)
        assert y_encoded is not None

    def test_save_checkpoint(self):
        """_save_checkpoint writes a file whose name starts with 'sde_epoch_<epoch>'."""
        with tempfile.TemporaryDirectory() as temp_dir:
            trainer = self._make_trainer(store_path=temp_dir)

            trainer._save_checkpoint(1, 0.5)

            # Check the file inside the context, before the directory is removed.
            files = os.listdir(temp_dir)
            assert any(f.startswith("sde_epoch_1") for f in files)

    def test_validate(self):
        """validate() returns a float loss plus five float metrics."""
        # Mock metrics. The same 5-tuple is exposed both via `.forward(...)` and
        # via calling the mock directly, so the test does not depend on which
        # entry point TrainSDE.validate uses (a bare Mock call would otherwise
        # return another Mock instead of the tuple).
        metric_values = (1.0, 0.1, 25.0, 0.8, 0.2)
        mock_metrics = Mock()
        mock_metrics.forward.return_value = metric_values
        mock_metrics.return_value = metric_values
        mock_metrics.fid = True
        mock_metrics.metrics = True
        mock_metrics.lpips = True

        trainer = self._make_trainer(
            val_loader=self.data_loader,
            metrics_=mock_metrics,
        )

        val_loss, fid, mse, psnr, ssim, lpips = trainer.validate()

        assert isinstance(val_loss, float)
        assert isinstance(fid, float)
        assert isinstance(mse, float)
        assert isinstance(psnr, float)
        assert isinstance(ssim, float)
        assert isinstance(lpips, float)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class TestSampleSDE:
    """Test cases for SampleSDE class."""

    def setup_method(self):
        """Build tiny stand-in models and a 10-step VP reverse SDE."""

        # Create simple models for testing
        class SimpleNoisePredictor(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, 3, 3, padding=1)

            def forward(self, x, t, y=None, mask=None):
                # Same signature as the TestTrainSDE stub (optional `mask`), so
                # the stub tolerates either calling convention. Conditioning
                # inputs are ignored; only the output shape matters.
                return self.conv(x)

        class SimpleConditionalModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = nn.Linear(77, 64)

            def forward(self, input_ids, attention_mask=None):
                return self.embed(input_ids.float())

        # Create components
        self.scheduler = VarianceSchedulerSDE(num_steps=10)
        self.reverse_sde = ReverseSDE(self.scheduler, "vp")
        self.noise_predictor = SimpleNoisePredictor()
        self.conditional_model = SimpleConditionalModel()

    def test_init(self):
        """Test initialization."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
        )

        assert sampler is not None

    def test_tokenize(self):
        """tokenize() returns one row per prompt for ids and attention mask."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            conditional_model=self.conditional_model,
        )

        # Test with single prompt
        input_ids, attention_mask = sampler.tokenize("a test prompt")
        assert input_ids.shape[0] == 1
        assert attention_mask.shape[0] == 1

        # Test with multiple prompts
        input_ids, attention_mask = sampler.tokenize(["prompt1", "prompt2"])
        assert input_ids.shape[0] == 2
        assert attention_mask.shape[0] == 2

    def test_forward_unconditional(self):
        """Unconditional sampling yields [0, 1] images and saves one file per image."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            batch_size=2,
        )

        with tempfile.TemporaryDirectory() as temp_dir:
            images = sampler.forward(
                conditions=None,
                save_images=True,
                save_path=temp_dir,
            )

            assert images.shape == (2, 3, 32, 32)
            assert torch.all(images >= 0) and torch.all(images <= 1)  # Normalized

            # Check if images were saved (inside the context, before cleanup)
            files = os.listdir(temp_dir)
            assert len(files) == 2

    def test_forward_conditional(self):
        """Text-conditioned sampling yields one image per prompt."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
            conditional_model=self.conditional_model,
            batch_size=2,
        )

        images = sampler.forward(
            conditions=["a cat", "a dog"],
            save_images=False,
        )

        assert images.shape == (2, 3, 32, 32)

    def test_to_device(self):
        """to(device) moves the sampler and its sub-modules to the target device."""
        sampler = SampleSDE(
            reverse_diffusion=self.reverse_sde,
            noise_predictor=self.noise_predictor,
            image_shape=(32, 32),
        )

        # Always target CPU so the test behaves the same with or without CUDA.
        target_device = torch.device("cpu")
        sampler = sampler.to(target_device)

        assert sampler.device == target_device
        assert next(sampler.noise_predictor.parameters()).device == target_device
        # ReverseSDE may own no nn.Parameter at all (not visible from this
        # file), in which case next(...) would raise StopIteration. Check every
        # parameter and buffer it does own instead.
        reverse_tensors = list(sampler.reverse.parameters()) + list(
            sampler.reverse.buffers()
        )
        assert all(t.device == target_device for t in reverse_tensors)
|
|
512
|
+
|
|
513
|
+
|
|
514
|
+
def test_integration():
    """End-to-end smoke test: train a tiny conditional VP-SDE model, then sample.

    Mirrors the provided usage code with shrunken sizes (10 diffusion steps,
    2 epochs, small channel counts, short context) so it finishes quickly.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tiny random dataset shared by the training and validation loaders.
    inputs = torch.randn(20, 3, 32, 32)
    labels = torch.randint(0, 10, (20,))
    dataset = TensorDataset(inputs, labels)
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=4, shuffle=False)

    # Deliberately small U-Net: two resolution levels, one block per stage.
    noise_predictor = NoisePredictor(
        in_channels=3,
        down_channels=[8, 16],
        mid_channels=[16, 16],
        up_channels=[16, 8],
        down_sampling=[True, False],  # a single downsampling step for 32x32 inputs
        time_embed_dim=32,
        y_embed_dim=32,
        num_down_blocks=1,
        num_mid_blocks=1,
        num_up_blocks=1,
        down_sampling_factor=2,
    ).to(device)

    # Untrained, single-layer text encoder to avoid loading pretrained weights.
    text_encoder = TextEncoder(
        use_pretrained_model=False,
        model_name="bert-base-uncased",
        vocabulary_size=100,
        num_layers=1,
        input_dimension=32,
        output_dimension=32,
        num_heads=2,
        context_length=10,
    ).to(device)

    trainable = [p for p in noise_predictor.parameters() if p.requires_grad]
    trainable += [p for p in text_encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4)
    objective = nn.MSELoss()

    scheduler = VarianceSchedulerSDE(
        num_steps=10,
        beta_start=1e-4,
        beta_end=0.02,
        trainable_beta=False,
        sigma_start=1e-3,
        sigma_end=10.0,
        start=0.0,
        end=1.0,
        beta_method="linear",
    )
    forward_sde = ForwardSDE(variance_scheduler=scheduler, sde_method="vp")
    reverse_sde = ReverseSDE(variance_scheduler=scheduler, sde_method="vp")

    with tempfile.TemporaryDirectory() as store_dir:
        trainer = TrainSDE(
            noise_predictor=noise_predictor,
            forward_diffusion=forward_sde,
            reverse_diffusion=reverse_sde,
            data_loader=train_loader,
            optimizer=optimizer,
            objective=objective,
            val_loader=val_loader,
            max_epochs=2,
            device=device,
            conditional_model=text_encoder,
            metrics_=None,
            store_path=store_dir,
            val_frequency=1,
            use_ddp=False,
            grad_accumulation_steps=1,
            log_frequency=1,
            use_compilation=False,
        )

        train_losses, best_val_loss = trainer()
        # NOTE(review): `len(...) >= 0` can never fail; it only proves that
        # train_losses is sized. Tighten once TrainSDE's return contract
        # (per-epoch vs. per-step losses) is pinned down.
        assert len(train_losses) >= 0
        assert isinstance(best_val_loss, float)

    sampler = SampleSDE(
        reverse_diffusion=reverse_sde,
        noise_predictor=noise_predictor,
        image_shape=(32, 32),
        conditional_model=text_encoder,
        tokenizer="bert-base-uncased",
        max_token_length=10,
        batch_size=2,
        in_channels=3,
        device=device,
        image_output_range=(-1.0, 1.0),
    )

    # One prompt per image in the batch.
    samples = sampler(["airplane", "automobile"], save_images=False)
    assert samples.shape == (2, 3, 32, 32)
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
if __name__ == "__main__":
    # Run this module's tests. pytest.main returns an exit code instead of
    # exiting, so propagate it; otherwise a failing run would still exit 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|