TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ddpm/noise_predictor.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NoisePredictor(nn.Module):
    """U-Net style noise predictor used by the DDPM pipeline.

    The input is encoded by a stack of ``DownBlock`` modules (the input of every
    down block is kept as a skip connection), processed by ``MiddleBlock``
    modules and decoded by ``UpBlock`` modules that consume the skip
    connections in reverse order. Every block is conditioned on the diffusion
    step through a sinusoidal embedding (``GetEmbeddedTime``) followed by a
    two-layer MLP (``time_projection``).

    Parameters
    ----------
    in_channels : int
        Channels of the tensor fed to ``conv1`` (i.e. of ``x``, or of
        ``cat([x, y])`` when ``where_y`` is False). ``conv2`` maps back to this
        same number of channels.
    down_channels : list[int]
        Encoder widths; ``len(down_channels) - 1`` down blocks are built and
        ``down_channels[0]`` is the width produced by ``conv1``.
    mid_channels : list[int]
        Bottleneck widths; ``len(mid_channels) - 1`` middle blocks are built.
    up_channels : list[int]
        Decoder widths; ``len(up_channels) - 1`` up blocks are built and
        ``up_channels[-1]`` is the width fed to ``conv2``.
    down_sampling : list[bool]
        Per down block, whether it downsamples; reversed for the up blocks.
    time_embed_dim : int
        Size of the sinusoidal time embedding (must be even).
    y_embed_dim : int
        Feature size of the conditioning tensor used for cross-attention.
    num_down_blocks, num_mid_blocks, num_up_blocks : int
        Number of residual layers inside each down / middle / up block.
    dropout_rate : float
        Dropout probability used throughout the network.
    down_sampling_factor : int
        Spatial factor of each down/up sampling step.
    where_y : bool
        True: ``y`` is used as cross-attention context inside the blocks.
        False: ``y`` is concatenated to ``x`` along the channel axis before
        ``conv1`` and the blocks use self-attention only.
    y_to_all : bool
        If True, attention follows every residual layer of every block;
        otherwise only the first layer of each block.

    Notes
    -----
    NOTE(review): ``UpBlock`` sizes its first convolution for
    ``in_channels // 2 + skip_channels`` with ``skip_channels`` taken from
    ``reversed(down_channels)``, whereas ``forward`` concatenates the upsampled
    tensor (about ``up_channels[i]`` channels) with a skip of
    ``down_channels[-2 - i]`` channels. These agree when
    ``up_channels[i] == down_channels[-1 - i] == 2 * down_channels[-2 - i]``
    (channel doubling per level); other layouts will likely fail at runtime —
    confirm before using them.
    """
    def __init__(
            self,
            in_channels,
            down_channels,
            mid_channels,
            up_channels,
            down_sampling,
            time_embed_dim,
            y_embed_dim,  # output embedding dimension in text conditional net
            num_down_blocks,
            num_mid_blocks,
            num_up_blocks,
            dropout_rate=0.1,
            down_sampling_factor=2,
            where_y=True,
            y_to_all=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        # The decoder mirrors the encoder, so its up/no-up pattern is the reversed down pattern.
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i+1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels)-1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i+1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels)-1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self):
        """Initialize model weights for better training stability.

        Applies Kaiming-normal init (leaky-ReLU gain, slope 0.2) to every
        Conv2d / Linear / ConvTranspose2d weight and zeroes their biases.
        Not called automatically by ``__init__``.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, y=None):
        """Predict the noise in ``x`` at diffusion steps ``t``.

        Parameters
        ----------
        x : torch.Tensor
            Noisy input batch for ``conv1`` (presumably ``(B, C, H, W)`` — confirm).
        t : torch.Tensor
            1-D tensor of time steps, one per sample (``GetEmbeddedTime``
            indexes it as ``t[:, None]``).
        y : torch.Tensor or None
            Conditioning. With ``where_y`` True it is the cross-attention
            context; with ``where_y`` False it is concatenated to ``x`` along
            dim 1 and must therefore share ``x``'s spatial size.

        Returns
        -------
        torch.Tensor
            Output of ``conv2`` with ``in_channels`` channels.
        """
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
            # y has been consumed as extra input channels; passing this
            # image-shaped tensor on as cross-attention context would crash in
            # Attention.y_projection, so the blocks fall back to self-attention.
            y = None
        output = self.conv1(x)
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)

        skip_connections = []
        for down in self.down_blocks:
            # Keep each down block's input; up blocks consume them in reverse order.
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        return self.conv2(output)
|
|
127
|
+
#-----------------------------------------------------------------------------
|
|
128
|
+
class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` time-conditioned residual layers with attention, then optional downsampling.

    Layer ``i`` computes ``conv2(conv1(h) + time_bias) + resnet(h)``, where
    ``time_bias`` is ``TimeEmbedding`` applied to the time embedding and
    broadcast over the spatial axes, and ``resnet`` is a 1x1 convolution
    shortcut that matches channel counts. An attention output is then added
    residually after the first layer only, or after every layer when
    ``y_to_all`` is True (cross-attention if ``y`` is given, otherwise
    self-attention).

    Parameters
    ----------
    in_channels, out_channels : int
        Input width of the first layer and width of all layer outputs.
    time_embed_dim : int
        Size of the incoming time embedding.
    y_embed_dim : int
        Feature size of the cross-attention context ``y``.
    num_layers : int
        Number of residual layers.
    down_sampling_factor : int
        Factor passed to ``DownSampling``.
    down_sample : bool
        If False, the final downsampling is an identity.
    dropout_rate : float
        Dropout probability for the conv and attention sub-layers.
    y_to_all : bool
        Apply attention after every layer instead of only the first.
    """
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim,num_layers, down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i==0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim= y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        # 1x1 shortcuts so the residual add works when channel counts change.
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, embed_time, y=None):
        """Run the residual layers and attention, then downsample.

        Parameters
        ----------
        x : torch.Tensor
            Feature map with ``in_channels`` channels.
        embed_time : torch.Tensor
            Time embedding; ``TimeEmbedding`` output is indexed as
            ``[:, :, None, None]``, so it is expected to be 2-D ``(B, dim)``.
        y : torch.Tensor or None, optional
            Cross-attention context. ``None`` (the default, matching
            ``MiddleBlock``/``UpBlock``) selects self-attention.

        Returns
        -------
        torch.Tensor
            Feature map with ``out_channels`` channels after ``down_sampling``.
        """
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            # Attention after the first layer, or after every layer with y_to_all.
            # Attention.forward treats y=None as self-attention.
            if self.y_to_all or i == 0:
                output = output + self.attention[i](output, y)
        return self.down_sampling(output)
|
|
211
|
+
#------------------------------------------------------------------------------
|
|
212
|
+
class MiddleBlock(nn.Module):
    """Bottleneck stage: ``num_layers + 1`` time-conditioned residual layers with attention in between.

    One residual layer runs first; then, for each of the ``num_layers`` steps,
    an attention output is added residually (only on the first step unless
    ``y_to_all`` is True) followed by another residual layer. A residual layer
    computes ``conv2(conv1(h) + time_bias) + resnet(h)``.

    ``num_layers + 1`` attention modules are constructed, but ``forward``
    only ever uses indices ``0 .. num_layers - 1``.
    """
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        depth = num_layers + 1  # leading residual layer + one per attention step
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for layer in range(depth)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(depth)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in range(depth)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(depth)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for layer in range(depth)
        ])

    def _residual_layer(self, idx, features, embed_time):
        """Residual layer ``idx``: conv -> + time bias -> conv, plus a 1x1 shortcut of the input."""
        hidden = self.conv1[idx](features)
        hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
        return self.conv2[idx](hidden) + self.resnet[idx](features)

    def forward(self, x, embed_time, y=None):
        """Apply the bottleneck; ``y`` is the cross-attention context (``None`` -> self-attention)."""
        features = self._residual_layer(0, x, embed_time)
        for step in range(self.num_layers):
            if self.y_to_all or step == 0:
                features = features + self.attention[step](features, y)
            features = self._residual_layer(step + 1, features, embed_time)
        return features
|
|
291
|
+
#------------------------------------------------------------------------------
|
|
292
|
+
class UpBlock(nn.Module):
    """Decoder stage: upsample, concatenate a skip connection, then ``num_layers`` residual layers with attention.

    Each residual layer computes ``conv2(conv1(h) + time_bias) + resnet(h)``.
    An attention output is added residually after the first layer only, or
    after every layer when ``y_to_all`` is True (cross-attention if ``y`` is
    given, self-attention otherwise).

    NOTE(review): the first conv/shortcut expect ``in_channels // 2 +
    skip_channels`` input channels, while ``forward`` concatenates the output
    of ``UpSampling(in_channels, in_channels)`` (or the untouched ``x`` when
    ``up_sampling`` is False) with ``skip_connection``. The caller must choose
    ``skip_channels`` so these agree — confirm against the enclosing network.
    """
    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        merged_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=merged_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for layer in range(num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in range(num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(num_layers)
        ])
        if up_sampling:
            self.up_sampling = UpSampling(
                in_channels=in_channels,
                out_channels=in_channels,
                up_sampling_factor=up_sampling_factor,
                conv_block=True,
                up_sampling=True
            )
        else:
            self.up_sampling = nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=merged_channels if layer == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for layer in range(num_layers)
        ])

    def _residual_layer(self, idx, features, embed_time):
        """Residual layer ``idx``: conv -> + time bias -> conv, plus a 1x1 shortcut of the input."""
        hidden = self.conv1[idx](features)
        hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
        return self.conv2[idx](hidden) + self.resnet[idx](features)

    def forward(self, x, skip_connection, embed_time, y=None):
        """Upsample ``x``, join it with ``skip_connection`` on dim 1 and decode; ``y=None`` -> self-attention."""
        features = torch.cat(tensors=[self.up_sampling(x), skip_connection], dim=1)
        for layer in range(self.num_layers):
            features = self._residual_layer(layer, features, embed_time)
            if self.y_to_all or layer == 0:
                features = features + self.attention[layer](features, y)
        return features
|
|
376
|
+
#------------------------------------------------------------------------
|
|
377
|
+
class Conv3(nn.Module):
    """Conv2d -> GroupNorm -> SiLU -> Dropout, with padding ``(kernel_size - 1) // 2``.

    For odd ``kernel_size`` the padding preserves the spatial size. ``norm``
    and ``activation`` can each be disabled, in which case that stage is an
    identity.
    """
    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)
        if norm:
            self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        else:
            self.group_norm = nn.Identity()
        if activation:
            self.activation = nn.SiLU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        """Apply conv, normalisation, activation and dropout in sequence."""
        for stage in (self.conv, self.group_norm, self.activation, self.dropout):
            batch = stage(batch)
        return batch
|
|
391
|
+
#----------------------------------------------------------------
|
|
392
|
+
class TimeEmbedding(nn.Module):
    """Map a time-step embedding of size ``embed_dim`` to ``output_dim`` features via SiLU then Linear."""
    def __init__(self, output_dim, embed_dim):
        super().__init__()
        layers = [
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim),
        ]
        self.embedding = nn.Sequential(*layers)

    def forward(self, batch):
        """Return the projected embedding; the last dimension becomes ``output_dim``."""
        projected = self.embedding(batch)
        return projected
|
|
401
|
+
#----------------------------------------------------------------
|
|
402
|
+
class GetEmbeddedTime(nn.Module):
    """Sinusoidal embedding of diffusion time steps.

    Step ``t`` maps to ``[sin(t / f_0), ..., sin(t / f_{D/2-1}), cos(t / f_0), ..., cos(t / f_{D/2-1})]``
    with ``f_i = 10000 ** (2 i / D)`` and ``D = embed_dim``: sines fill the
    first half of the output, cosines the second half.

    Parameters
    ----------
    embed_dim : int
        Output embedding size; must be even.

    Raises
    ------
    ValueError
        If ``embed_dim`` is odd. (An explicit check rather than ``assert`` so
        the validation is not stripped under ``python -O``.)
    """
    def __init__(self, embed_dim):
        super().__init__()
        if embed_dim % 2 != 0:
            raise ValueError("The embedding dimension must be divisible by two")
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        """Embed ``time_steps``.

        Parameters
        ----------
        time_steps : torch.Tensor
            1-D tensor of time steps (one per sample). A 0-dim tensor is
            treated as a batch of one.

        Returns
        -------
        torch.Tensor
            Float tensor of shape ``(len(time_steps), embed_dim)``.
        """
        if time_steps.dim() == 0:
            # A lone scalar step would otherwise fail on the [:, None] indexing below.
            time_steps = time_steps.unsqueeze(0)
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time
|
|
414
|
+
#----------------------------------------------------------------
|
|
415
|
+
class Attention(nn.Module):
    """Multi-head attention over the spatial positions of a feature map.

    The ``(B, C, H, W)`` input is flattened to ``H * W`` tokens of width ``C``.
    With a context ``y``, the tokens attend to ``y_projection(y)``
    (cross-attention); without one, they attend to themselves. The result is
    reshaped back to ``(B, C, H, W)``, group-normalised and passed through
    dropout. The caller adds it residually.
    """
    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        """Attend from ``x``'s spatial positions to ``y`` (2-D or 3-D after projection) or to ``x`` itself."""
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        tokens = x.view(batch_size, channels, h * w).transpose(1, 2)

        if y is None:
            context = tokens
        else:
            context = self.y_projection(y)
            rank = context.dim()
            if rank == 2:
                # A single context vector per sample becomes a length-1 sequence.
                context = context.unsqueeze(1)
            elif rank != 3:
                raise ValueError(
                    f"Expected y to be 2D or 3D after projection, got {context.dim()}D with shape {context.shape}"
                )
            if context.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {context.shape[-1]}"
                )

        attended, _ = self.attention(tokens, context, context)
        restored = attended.transpose(1, 2).view(batch_size, channels, h, w)
        return self.dropout(self.norm(restored))
|
|
451
|
+
#-----------------------------------------------------------------
|
|
452
|
+
class DownSampling(nn.Module):
    """Spatial downsampling by ``down_sampling_factor`` via a strided-conv branch and/or a max-pool branch.

    When both branches are enabled, each produces ``out_channels // 2``
    channels and the results are concatenated on dim 1. When only one is
    enabled, it produces all ``out_channels``. When neither is, the input is
    returned unchanged.
    """
    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        conv_width = out_channels // 2 if max_pool else out_channels
        pool_width = out_channels // 2 if conv_block else out_channels
        if conv_block:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
                nn.Conv2d(in_channels=in_channels, out_channels=conv_width,
                          kernel_size=3, stride=down_sampling_factor, padding=1),
            )
        else:
            self.conv = nn.Identity()
        if max_pool:
            self.pool = nn.Sequential(
                nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
                nn.Conv2d(in_channels=in_channels, out_channels=pool_width,
                          kernel_size=1, stride=1, padding=0),
            )
        else:
            self.pool = nn.Identity()

    def forward(self, batch):
        """Downsample ``batch`` using whichever branches are enabled."""
        if self.conv_block and self.max_pool:
            return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
        if self.conv_block:
            return self.conv(batch)
        return self.pool(batch)
|
|
475
|
+
#--------------------------------------------------------------------------
|
|
476
|
+
class UpSampling(nn.Module):
    """Spatial upsampling by ``up_sampling_factor`` via a transposed-conv branch and/or a nearest-neighbour branch.

    When both branches are enabled, each produces ``out_channels // 2``
    channels. The nearest-neighbour result is resized to the conv result's
    spatial size if they differ, and the two are concatenated on dim 1. When
    only one branch is enabled, it produces all ``out_channels``. When neither
    is, the input is returned unchanged.
    """
    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor
        half_out_channels = out_channels // 2
        conv_width = half_out_channels if up_sampling else out_channels
        upsample_width = half_out_channels if conv_block else out_channels
        if conv_block:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=conv_width,
                    kernel_size=3,
                    stride=up_sampling_factor,
                    padding=1,
                    output_padding=up_sampling_factor - 1
                ),
                nn.Conv2d(in_channels=conv_width, out_channels=conv_width, kernel_size=1, stride=1, padding=0),
            )
        else:
            self.conv = nn.Identity()
        if up_sampling:
            self.up_sample = nn.Sequential(
                nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
                nn.Conv2d(in_channels=in_channels, out_channels=upsample_width,
                          kernel_size=1, stride=1, padding=0),
            )
        else:
            self.up_sample = nn.Identity()

    def forward(self, batch):
        """Upsample ``batch`` using whichever branches are enabled."""
        if not (self.conv_block and self.up_sampling):
            return self.conv(batch) if self.conv_block else self.up_sample(batch)
        conv_output = self.conv(batch)
        resized = self.up_sample(batch)
        target_hw = conv_output.shape[2:]
        if resized.shape[2:] != target_hw:
            # Guard against off-by-one spatial sizes between the two branches.
            resized = torch.nn.functional.interpolate(
                resized,
                size=tuple(target_hw),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, resized], dim=1)
|
ddpm/reverse_ddpm.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Reverse diffusion process for Denoising Diffusion Probabilistic Models (DDPM).
|
|
2
|
+
|
|
3
|
+
This module implements the reverse diffusion process as described in the DDPM paper
|
|
4
|
+
(Ho et al., 2020, "Denoising Diffusion Probabilistic Models"). The reverse process
|
|
5
|
+
gradually denoises a noisy input to reconstruct the original data distribution.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ReverseDDPM(nn.Module):
    """Reverse diffusion process of DDPM.

    Implements one reverse step of DDPM (Ho et al., 2020): given a noisy input
    ``xt`` and the predicted noise, it computes the posterior mean and, for
    ``t > 0``, adds Gaussian noise with variance
    ``(1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t``.

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object holding the noise schedule. Expected attributes:
            - `num_steps`: Number of diffusion steps (int).
            - `trainable_beta`: Whether the noise schedule is trainable (bool).
            - `betas`: Noise schedule (torch.Tensor). Always required: used
              directly when `trainable_beta` is False, and passed to
              `compute_schedule` when it is True.
            - `alphas`: Precomputed alpha values (torch.Tensor; required if
              trainable_beta is False).
            - `alpha_bars`: Precomputed cumulative product of alphas
              (torch.Tensor; required if trainable_beta is False).
            - `compute_schedule`: Callable taking `betas` and returning a
              5-tuple whose first three items are the full betas, alphas and
              alpha_bars schedules (required if trainable_beta is True).

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object for use in the reverse process.
    """
    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params  # hyperparameters class

    @staticmethod
    def _gather(schedule, index, device):
        """Return ``schedule[index]`` moved to ``device``.

        The index is first moved to the schedule's own device, so CPU
        schedules work with CUDA indices and vice versa.
        """
        return schedule[index.to(schedule.device)].to(device)

    def forward(self, xt, predicted_noise, time_steps):
        """Applies the reverse diffusion process to the noisy input.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, of shape (batch_size, channels, height, width).
        predicted_noise : torch.Tensor
            Predicted noise tensor, of the same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Denoised tensor `xt_minus_1` at time step `t-1`, with the same shape as `xt`.
            Samples with time_steps == 0 receive the posterior mean without added noise.

        Raises
        ------
        ValueError
            If any value in `time_steps` is outside the valid range
            [0, hyper_params.num_steps - 1].
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.num_steps - 1}")

        if self.hyper_params.trainable_beta:
            betas, alphas, alpha_bars, _, _ = self.hyper_params.compute_schedule(self.hyper_params.betas)
        else:
            betas = self.hyper_params.betas
            alphas = self.hyper_params.alphas
            alpha_bars = self.hyper_params.alpha_bars

        device = xt.device
        # Gather per-sample values from the FULL schedules (the previous
        # trainable branch indexed the already-gathered alpha_bars_t a second time).
        betas_t = self._gather(betas, time_steps, device).view(-1, 1, 1, 1)
        alphas_t = self._gather(alphas, time_steps, device)
        alpha_bars_t = self._gather(alpha_bars, time_steps, device)

        sqrt_alphas_t = torch.sqrt(alphas_t).view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bars_t = torch.sqrt(1 - alpha_bars_t).view(-1, 1, 1, 1)

        mu = (xt - (betas_t / sqrt_one_minus_alpha_bars_t) * predicted_noise) / sqrt_alphas_t

        mask = (time_steps == 0).to(device)
        if mask.all():
            return mu

        # alpha_bar_{t-1}; clamping keeps t == 0 from wrapping around to the
        # last step (those samples receive no noise below anyway).
        alpha_bars_t_minus_1 = self._gather(alpha_bars, (time_steps - 1).clamp(min=0), device)
        variance = (1 - alpha_bars_t_minus_1) / (1 - alpha_bars_t) * betas_t.view(-1)
        # Zero the variance for t == 0 with torch.where instead of multiplying
        # by a 0/1 mask, so a non-finite value there cannot turn into NaN.
        variance = torch.where(mask, torch.zeros_like(variance), variance)
        std = torch.sqrt(variance).view(-1, 1, 1, 1)

        z = torch.randn_like(xt)
        xt_minus_1 = mu + std * z
        return xt_minus_1
|