TorchDiff 2.0.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
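
Orientation note for reviewers: per top_level.txt, the wheel installs six top-level packages — the per-method trees ddim, ddpm, ldm, sde, unclip, plus the aggregated torchdiff package with its bundled tests. A hedged sketch of what that layout implies for imports; the module paths are read off the file list above, and only the two classes in the modules diffed below are verified here:

    # After `pip install torchdiff==2.0.0`, both styles resolve, since the wheel
    # installs `ddim/` and `torchdiff/` as separate top-level packages.
    import torchdiff                                  # aggregated API: torchdiff/ddim.py, torchdiff/ldm.py, ...
    from ddim.noise_predictor import NoisePredictor   # flat per-method tree, installed top-level
    from ddim.reverse_ddim import ReverseDDIM

Note that this layout places generic names such as ddim and sde directly on sys.path, where they can collide with other installed packages.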
ddim/noise_predictor.py
ADDED
@@ -0,0 +1,521 @@

import torch
import torch.nn as nn


class NoisePredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        down_channels,
        mid_channels,
        up_channels,
        down_sampling,
        time_embed_dim,
        y_embed_dim,  # output embedding dimension in text conditional net
        num_down_blocks,
        num_mid_blocks,
        num_up_blocks,
        dropout_rate=0.1,
        down_sampling_factor=2,
        where_y=True,
        y_to_all=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels) - 1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i + 1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels) - 1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self):
        """Initialize model weights for better training stability."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, y=None):
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)
        skip_connections = []

        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        output = self.conv2(output)
        return output
#-----------------------------------------------------------------------------
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, embed_time, y):
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            if y is not None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is not None and self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is None and self.y_to_all:
                out_attn = self.attention[i](output)
                output = output + out_attn
            elif y is None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output)
                output = output + out_attn

        output = self.down_sampling(output)
        return output
#------------------------------------------------------------------------------
class MiddleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers + 1)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers + 1)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers + 1)
        ])

    def forward(self, x, embed_time, y=None):
        output = x
        resnet_input = output
        output = self.conv1[0](output)
        output = output + self.time_embedding[0](embed_time)[:, :, None, None]
        output = self.conv2[0](output)
        output = output + self.resnet[0](resnet_input)
        for i in range(self.num_layers):
            if y is not None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is not None and self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is None and self.y_to_all:
                out_attn = self.attention[i](output)
                output = output + out_attn
            elif y is None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output)
                output = output + out_attn
            resnet_input = output
            output = self.conv1[i + 1](output)
            output = output + self.time_embedding[i + 1](embed_time)[:, :, None, None]
            output = self.conv2[i + 1](output)
            output = output + self.resnet[i + 1](resnet_input)

        return output
#------------------------------------------------------------------------------
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        effective_in_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor,
            conv_block=True,
            up_sampling=True
        ) if up_sampling else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, skip_connection, embed_time, y=None):
        x = self.up_sampling(x)
        x = torch.cat(tensors=[x, skip_connection], dim=1)
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            if y is not None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is not None and self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif y is None and self.y_to_all:
                out_attn = self.attention[i](output)
                output = output + out_attn
            elif y is None and not self.y_to_all and i == 0:
                out_attn = self.attention[i](output)
                output = output + out_attn
        return output
#------------------------------------------------------------------------
class Conv3(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels) if norm else nn.Identity()
        self.activation = nn.SiLU() if activation else nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        batch = self.conv(batch)
        batch = self.group_norm(batch)
        batch = self.activation(batch)
        batch = self.dropout(batch)
        return batch
#----------------------------------------------------------------
class TimeEmbedding(nn.Module):
    def __init__(self, output_dim, embed_dim):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim)
        )

    def forward(self, batch):
        return self.embedding(batch)
#----------------------------------------------------------------
class GetEmbeddedTime(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time
#----------------------------------------------------------------
class Attention(nn.Module):
    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        x_reshaped = x.view(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() != 3:
                if y.dim() == 2:
                    y = y.unsqueeze(1)
                else:
                    raise ValueError(
                        f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                    )
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).view(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
#-----------------------------------------------------------------
class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
#--------------------------------------------------------------------------
class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor
        half_out_channels = out_channels // 2
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=half_out_channels if up_sampling else out_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if conv_block else nn.Identity()

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=half_out_channels if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if up_sampling else nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
ddim/reverse_ddim.py
ADDED
@@ -0,0 +1,91 @@

"""Reverse diffusion process for Denoising Diffusion Implicit Models (DDIM).

This module implements the reverse diffusion process for DDIM, as described in Song et al.
(2021, "Denoising Diffusion Implicit Models"). The reverse process iteratively denoises a
noisy input to reconstruct the original data distribution using a subset of time steps.
"""

import torch
import torch.nn as nn


class ReverseDDIM(nn.Module):
    """Reverse diffusion process of DDIM.

    Implements the reverse diffusion process for Denoising Diffusion Implicit Models
    (DDIM), which denoises a noisy input `xt` using a predicted noise component and a
    subsampled time step schedule, as defined in Song et al. (2021).

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing the noise schedule parameters. Expected to have
        attributes:
        - `tau_num_steps`: Number of subsampled time steps (int).
        - `eta`: Noise scaling factor for the reverse process (float).
        - `get_tau_schedule`: Method to compute the subsampled noise schedule (callable),
          returning a tuple of (betas, alphas, alpha_bars, sqrt_alpha_cumprod,
          sqrt_one_minus_alpha_cumprod).

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    """
    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, xt, predicted_noise, time_steps, prev_time_steps):
        """Applies the reverse diffusion process to the noisy input.

        Denoises the input `xt` at time step `t` to produce the previous step `xt_prev`
        at `prev_time_steps` using the predicted noise and the DDIM reverse process.
        Optionally includes stochastic noise scaled by `eta`.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, hyper_params.tau_num_steps - 1].
        prev_time_steps : torch.Tensor
            Tensor of previous time step indices (long), shape (batch_size,), where each
            value is in the range [0, hyper_params.tau_num_steps - 1].

        Returns
        -------
        tuple
            A tuple containing:
            - xt_prev: Denoised tensor at `prev_time_steps`, same shape as `xt`.
            - x0: Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If any value in `time_steps` or `prev_time_steps` is outside the valid range
            [0, hyper_params.tau_num_steps - 1].
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < self.hyper_params.tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.hyper_params.get_tau_schedule()
        tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[time_steps].to(xt.device).view(-1, 1, 1, 1)
        tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[time_steps].to(xt.device).view(-1, 1, 1, 1)
        prev_tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[prev_time_steps].to(xt.device).view(-1, 1, 1, 1)
        prev_tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[prev_time_steps].to(xt.device).view(-1, 1, 1, 1)

        eta = self.hyper_params.eta
        x0 = (xt - tau_sqrt_one_minus_alpha_cumprod_t * predicted_noise) / tau_sqrt_alpha_cumprod_t
        noise_coeff = eta * ((tau_sqrt_one_minus_alpha_cumprod_t / prev_tau_sqrt_alpha_cumprod_t) *
                             prev_tau_sqrt_one_minus_alpha_cumprod_t / torch.clamp(tau_sqrt_one_minus_alpha_cumprod_t, min=1e-8))
        direction_coeff = torch.clamp(prev_tau_sqrt_one_minus_alpha_cumprod_t ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = prev_tau_sqrt_alpha_cumprod_t * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0