TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
sde/noise_predictor.py
ADDED
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NoisePredictor(nn.Module):
    """U-Net style noise predictor conditioned on time and, optionally, on ``y``.

    Pipeline: ``conv1`` -> down blocks (each input is stored as a skip
    connection) -> middle blocks -> up blocks (each consumes one skip, last
    stored first) -> ``conv2``, which maps back to ``in_channels`` channels.

    Parameters
    ----------
    in_channels : int
        Channels of the tensor fed to ``conv1``; ``conv2`` outputs this many
        channels as well.
    down_channels : sequence of int
        Down block ``i`` maps ``down_channels[i]`` -> ``down_channels[i + 1]``.
    mid_channels : sequence of int
        Middle block ``i`` maps ``mid_channels[i]`` -> ``mid_channels[i + 1]``.
    up_channels : sequence of int
        Up block ``i`` maps ``up_channels[i]`` -> ``up_channels[i + 1]``.
    down_sampling : sequence of bool
        Per down block, whether it spatially down-samples; the reversed list
        drives the up blocks' up-sampling.
    time_embed_dim : int
        Size of the sinusoidal time embedding (must be even).
    y_embed_dim : int
        Last-dimension size of ``y`` as consumed by the attention projections.
    num_down_blocks, num_mid_blocks, num_up_blocks : int
        Number of residual layers inside each down / middle / up block.
    dropout_rate : float
        Dropout probability used throughout.
    down_sampling_factor : int
        Spatial down/up-sampling factor.
    where_y : bool
        If True, ``y`` is passed to the blocks' attention layers. If False,
        ``y`` is concatenated to ``x`` along the channel axis before ``conv1``
        and the blocks run self-attention only.
    y_to_all : bool
        If True, attention runs after every layer of each block; otherwise
        only after the first layer.
    """

    def __init__(
        self,
        in_channels,
        down_channels,
        mid_channels,
        up_channels,
        down_sampling,
        time_embed_dim,
        y_embed_dim,  # output embedding dimension in text conditional net
        num_down_blocks,
        num_mid_blocks,
        num_up_blocks,
        dropout_rate=0.1,
        down_sampling_factor=2,
        where_y=True,
        y_to_all=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.y_to_all = y_to_all
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1,
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all,
            ) for i in range(len(self.down_channels) - 1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all,
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        # NOTE(review): skip_channels[i] is down_channels[-1 - i], but the skip
        # tensor popped for up block i was stored *before* a down block and so
        # has down_channels[-2 - i] channels. UpBlock's `in_channels // 2 +
        # skip_channels` accounting compensates only for configurations where
        # channels double per stage and up_channels mirror down_channels —
        # TODO confirm for other configurations.
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i + 1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all,
            ) for i in range(len(self.up_channels) - 1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1),
        )

    def initialize_weights(self):
        """Initialize model weights for better training stability.

        Applies Kaiming-normal init (leaky-ReLU gain, a=0.2) to every Conv2d,
        ConvTranspose2d and Linear weight and zeroes their biases. Not called
        by ``__init__``; callers must invoke it explicitly.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, y=None):
        """Predict noise for ``x`` at time steps ``t``.

        Parameters
        ----------
        x : torch.Tensor
            Input batch (4-D, channels on dim 1).
        t : torch.Tensor
            Time steps, one per sample (see ``GetEmbeddedTime``).
        y : torch.Tensor or None
            Conditioning. With ``where_y=True`` it goes to the attention
            layers; with ``where_y=False`` it is concatenated to ``x`` on
            dim 1 and therefore must share ``x``'s batch and spatial shape.

        Returns
        -------
        torch.Tensor
            Output of ``conv2`` with ``in_channels`` channels.
        """
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
            # y has been consumed as extra input channels. It must not also reach
            # Attention.y_projection, which expects an embedding (2-D/3-D after
            # projection) and would fail on this spatial tensor.
            y = None
        output = self.conv1(x)
        # GetEmbeddedTime has no parameters, so building it per call is cheap.
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)

        skip_connections = []
        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            # Skips are consumed last-in, first-out to mirror the down path.
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        return self.conv2(output)
|
|
127
|
+
#-----------------------------------------------------------------------------
|
|
128
|
+
class DownBlock(nn.Module):
    """Encoder stage: residual layers with time conditioning and attention.

    Every layer runs Conv3 -> add projected time embedding -> Conv3, then adds
    a 1x1-conv shortcut of the layer input. An attention residual follows the
    first layer (or every layer when ``y_to_all`` is True); it is
    cross-attention when ``y`` is given and self-attention otherwise. The
    stage finishes with a DownSampling (or identity when ``down_sample`` is
    falsy).
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        # Only the first layer consumes in_channels; the rest run at out_channels.
        layer_in_channels = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]

        def make_conv3(channels_in):
            return Conv3(
                in_channels=channels_in,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            )

        # Submodules are built in the original order so that seeded parameter
        # initialisation is unchanged.
        self.conv1 = nn.ModuleList([make_conv3(c) for c in layer_in_channels])
        self.conv2 = nn.ModuleList([make_conv3(out_channels) for _ in range(num_layers)])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in range(num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            )
            for _ in range(num_layers)
        ])
        if down_sample:
            self.down_sampling = DownSampling(
                in_channels=out_channels,
                out_channels=out_channels,
                down_sampling_factor=down_sampling_factor,
                conv_block=True,
                max_pool=True,
            )
        else:
            self.down_sampling = nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=c, out_channels=out_channels, kernel_size=1)
            for c in layer_in_channels
        ])

    def forward(self, x, embed_time, y):
        """Run all residual/attention layers, then down-sample.

        ``embed_time`` is the projected time embedding (2-D, broadcast over
        the spatial dims); ``y`` may be None.
        """
        hidden = x
        for idx in range(self.num_layers):
            shortcut_source = hidden
            hidden = self.conv1[idx](hidden)
            hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
            hidden = self.conv2[idx](hidden)
            hidden = hidden + self.resnet[idx](shortcut_source)
            # Attention.forward treats y=None as self-attention.
            if self.y_to_all or idx == 0:
                hidden = hidden + self.attention[idx](hidden, y)
        return self.down_sampling(hidden)
|
|
211
|
+
#------------------------------------------------------------------------------
|
|
212
|
+
class MiddleBlock(nn.Module):
    """Bottleneck stage: ``num_layers + 1`` residual layers with attention between them.

    Layer 0 runs first; then for each of the ``num_layers`` iterations an
    attention residual (first iteration only, or every iteration when
    ``y_to_all``) is applied, followed by the next residual layer. The last
    entry of ``self.attention`` is constructed but not used by ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        total_layers = num_layers + 1
        layer_in_channels = [in_channels if idx == 0 else out_channels for idx in range(total_layers)]

        def make_conv3(channels_in):
            return Conv3(
                in_channels=channels_in,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            )

        # Same construction order as before (seeded initialisation unchanged).
        self.conv1 = nn.ModuleList([make_conv3(c) for c in layer_in_channels])
        self.conv2 = nn.ModuleList([make_conv3(out_channels) for _ in range(total_layers)])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in range(total_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            )
            for _ in range(total_layers)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=c, out_channels=out_channels, kernel_size=1)
            for c in layer_in_channels
        ])

    def _residual_layer(self, idx, hidden, embed_time):
        """Apply residual layer ``idx``: Conv3, add time embedding, Conv3, add 1x1 shortcut."""
        out = self.conv1[idx](hidden)
        out = out + self.time_embedding[idx](embed_time)[:, :, None, None]
        out = self.conv2[idx](out)
        return out + self.resnet[idx](hidden)

    def forward(self, x, embed_time, y=None):
        """Run the bottleneck; ``y`` may be None (attention then self-attends)."""
        hidden = self._residual_layer(0, x, embed_time)
        for idx in range(self.num_layers):
            if self.y_to_all or idx == 0:
                hidden = hidden + self.attention[idx](hidden, y)
            hidden = self._residual_layer(idx + 1, hidden, embed_time)
        return hidden
|
|
291
|
+
#------------------------------------------------------------------------------
|
|
292
|
+
class UpBlock(nn.Module):
    """Decoder stage: up-sample, concatenate a skip connection, then residual layers.

    Each of the ``num_layers`` layers runs Conv3 -> add projected time
    embedding -> Conv3 and adds a 1x1-conv shortcut of its input. An attention
    residual follows the first layer (or every layer when ``y_to_all``).
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        # NOTE(review): after UpSampling the tensor has 2 * (in_channels // 2)
        # channels (in_channels when up_sampling is off), while this formula
        # assumes in_channels // 2. It matches the real concat width only
        # because NoisePredictor passes a skip_channels value one stage
        # "ahead" — verify for configurations where channels do not double.
        effective_in_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate,
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim,
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate,
            ) for _ in range(self.num_layers)
        ])
        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor,
            conv_block=True,
            up_sampling=True,
        ) if up_sampling else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1,
            ) for i in range(num_layers)
        ])

    def forward(self, x, skip_connection, embed_time, y=None):
        """Up-sample ``x``, concatenate ``skip_connection`` on dim 1 and refine.

        If the up-sampled ``x`` and ``skip_connection`` differ spatially (e.g.
        an odd size was halved on the way down), ``x`` is resized to the skip's
        spatial size with nearest-neighbour interpolation so the concat succeeds.
        """
        x = self.up_sampling(x)
        if x.shape[2:] != skip_connection.shape[2:]:
            x = torch.nn.functional.interpolate(x, size=tuple(skip_connection.shape[2:]), mode='nearest')
        output = torch.cat(tensors=[x, skip_connection], dim=1)
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            # Attention.forward treats y=None as self-attention.
            if self.y_to_all or i == 0:
                output = output + self.attention[i](output, y)
        return output
|
|
376
|
+
#------------------------------------------------------------------------
|
|
377
|
+
class Conv3(nn.Module):
    """Conv2d -> GroupNorm (optional) -> SiLU (optional) -> Dropout.

    Padding ``(kernel_size - 1) // 2`` keeps the spatial size unchanged for
    odd kernel sizes.
    """

    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)
        if norm:
            self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        else:
            self.group_norm = nn.Identity()
        if activation:
            self.activation = nn.SiLU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        for stage in (self.conv, self.group_norm, self.activation, self.dropout):
            batch = stage(batch)
        return batch
|
|
391
|
+
#----------------------------------------------------------------
|
|
392
|
+
class TimeEmbedding(nn.Module):
    """Map a time embedding of size ``embed_dim`` to ``output_dim`` features (SiLU, then Linear)."""

    def __init__(self, output_dim, embed_dim):
        super().__init__()
        layers = [
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim),
        ]
        self.embedding = nn.Sequential(*layers)

    def forward(self, batch):
        projected = self.embedding(batch)
        return projected
|
|
401
|
+
#----------------------------------------------------------------
|
|
402
|
+
class GetEmbeddedTime(nn.Module):
    """Sinusoidal embedding of time steps (parameter-free).

    For frequency index ``i`` in ``[0, embed_dim // 2)`` the angle is
    ``t / 10000 ** (2 * i / embed_dim)``; the result is
    ``cat(sin(angles), cos(angles))`` along the last dimension.

    Parameters
    ----------
    embed_dim : int
        Output embedding size; must be even.
    """

    def __init__(self, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        """Embed ``time_steps``.

        A 1-D tensor of shape ``(batch,)`` yields ``(batch, embed_dim)``. A 0-d
        scalar tensor is treated as a batch of one (previously it crashed on
        ``time_steps[:, None]``).
        """
        if time_steps.dim() == 0:
            time_steps = time_steps.unsqueeze(0)
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        return torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
|
|
414
|
+
#----------------------------------------------------------------
|
|
415
|
+
class Attention(nn.Module):
    """Multi-head attention over the spatial positions of a 4-D feature map.

    With ``y`` given, positions attend to the projected ``y`` tokens
    (cross-attention); otherwise they attend to each other (self-attention).
    The attended map is GroupNorm-ed and passed through dropout; the caller
    adds it as a residual.
    """

    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def _prepare_context(self, y):
        """Project ``y`` to ``in_channels`` features and ensure it is 3-D (batch, tokens, features)."""
        y = self.y_projection(y)
        if y.dim() == 2:
            y = y.unsqueeze(1)
        elif y.dim() != 3:
            raise ValueError(
                f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
            )
        if y.shape[-1] != self.in_channels:
            raise ValueError(
                f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
            )
        return y

    def forward(self, x, y=None):
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.view(batch_size, channels, h * w).permute(0, 2, 1)
        context = tokens if y is None else self._prepare_context(y)
        attended, _ = self.attention(tokens, context, context)
        feature_map = attended.permute(0, 2, 1).view(batch_size, channels, h, w)
        return self.dropout(self.norm(feature_map))
|
|
451
|
+
#-----------------------------------------------------------------
|
|
452
|
+
class DownSampling(nn.Module):
    """Spatial down-sampling by ``down_sampling_factor`` using a conv and/or max-pool branch.

    With both branches enabled, each produces ``out_channels // 2`` channels
    and the results are concatenated on dim 1. With only one branch, that
    branch produces ``out_channels``. With neither, the input is returned
    unchanged.
    """

    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        conv_output = self.conv(batch)
        pool_output = self.pool(batch)
        if pool_output.shape[2:] != conv_output.shape[2:]:
            # The strided conv (k=3, p=1) yields ceil(H / factor) while max-pool
            # yields floor(H / factor), so they differ for sizes not divisible by
            # the factor. Align the pool branch (as UpSampling does) so the
            # concat below succeeds instead of raising.
            _, _, h, w = conv_output.shape
            pool_output = torch.nn.functional.interpolate(pool_output, size=(h, w), mode='nearest')
        return torch.cat(tensors=[conv_output, pool_output], dim=1)
|
|
475
|
+
#--------------------------------------------------------------------------
|
|
476
|
+
class UpSampling(nn.Module):
    """Spatial up-sampling by ``up_sampling_factor`` using a transposed-conv and/or interpolation branch.

    With both branches enabled, each produces ``out_channels // 2`` channels
    and the results are concatenated on dim 1 (the interpolation branch is
    resized to the conv branch's spatial size if they disagree). With only one
    branch, that branch produces ``out_channels``. With neither, the input is
    returned unchanged.
    """

    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor
        half_out_channels = out_channels // 2
        conv_channels = half_out_channels if up_sampling else out_channels
        interp_channels = half_out_channels if conv_block else out_channels

        if conv_block:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=3,
                    stride=up_sampling_factor,
                    padding=1,
                    output_padding=up_sampling_factor - 1,
                ),
                nn.Conv2d(
                    in_channels=conv_channels,
                    out_channels=conv_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ),
            )
        else:
            self.conv = nn.Identity()

        if up_sampling:
            self.up_sample = nn.Sequential(
                nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
                nn.Conv2d(in_channels=in_channels, out_channels=interp_channels,
                          kernel_size=1, stride=1, padding=0),
            )
        else:
            self.up_sample = nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        learned = self.conv(batch)
        interpolated = self.up_sample(batch)
        target_hw = tuple(learned.shape[2:])
        if tuple(interpolated.shape[2:]) != target_hw:
            interpolated = torch.nn.functional.interpolate(
                interpolated,
                size=target_hw,
                mode='nearest',
            )
        return torch.cat(tensors=[learned, interpolated], dim=1)
|
sde/reverse_sde.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ReverseSDE(nn.Module):
    """Reverse diffusion process for SDE-based generative models.

    Implements one reverse (denoising) update for score-based generative
    models using Stochastic Differential Equations (SDEs), supporting Variance
    Exploding (VE), Variance Preserving (VP), sub-Variance Preserving (sub-VP)
    and a deterministic ODE update, with method names following Song et al.
    (2021). The update is driven by predicted noise estimates.

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object. Attributes read, per method:

        - `dt`: time step size for the update (float), all methods.
        - `sigmas`: per-step sigma values (torch.Tensor), "ve" only.
        - `betas`: per-step beta values (torch.Tensor), "vp", "sub-vp", "ode".
        - `cum_betas`: per-step cumulative beta values (torch.Tensor), "sub-vp" only.
    method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    method : str
        Selected SDE method; validated when `forward` is called.
    """

    _SUPPORTED_METHODS = ("ve", "vp", "sub-vp", "ode")

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    def forward(self, xt, noise, predicted_noise, time_steps):
        """Applies one reverse SDE update to the noisy input.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        noise : torch.Tensor or None
            Gaussian noise tensor, same shape as `xt`, used for stochasticity. If None,
            no stochastic term is added. Ignored by "ode".
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), indexing the
            per-step hyperparameter tensors.

        Returns
        -------
        torch.Tensor
            Updated tensor for the previous time step, same shape as `xt`.

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp", "ode").
            Checked before any hyperparameter is read.

        Notes
        -----
        - For the "ve" and "ode" methods, the output is clamped to [-1e5, 1e5] to prevent
          numerical instability.
        - For "ve", the previous sigma is chosen per sample: samples at step 0 use 0,
          all others use `sigmas[t - 1]` (a batch mixing step 0 with later steps no
          longer zeroes the previous sigma for every sample).
        - Only the hyperparameters the selected method needs are read, so e.g. a VE
          configuration does not need `betas` or `cum_betas`.
        """
        if self.method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unknown method: {self.method}")
        dt = self.hyper_params.dt

        if self.method == "ve":
            return self._ve_step(xt, noise, predicted_noise, time_steps, dt)

        betas = self.hyper_params.betas[time_steps].view(-1, 1, 1, 1)

        if self.method == "vp":
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            return xt + drift + diffusion

        if self.method == "sub-vp":
            cum_betas = self.hyper_params.cum_betas[time_steps].view(-1, 1, 1, 1)
            # The same factor scales both the noise term and the diffusion variance.
            sub_vp_scale = 1 - torch.exp(-2 * cum_betas)
            drift = -0.5 * betas * xt * dt - betas * sub_vp_scale * predicted_noise * dt
            diffusion = torch.sqrt(betas * sub_vp_scale * dt) * noise if noise is not None else 0
            return xt + drift + diffusion

        # "ode": deterministic beta-based update, no stochastic term. (The former
        # nested `if self.method == "ve"` check here could never be true.)
        drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
        return torch.clamp(xt + drift, -1e5, 1e5)

    def _ve_step(self, xt, noise, predicted_noise, time_steps, dt):
        """One VE update with a per-sample previous sigma (0 for samples at step 0)."""
        sigmas = self.hyper_params.sigmas
        sigma_t = sigmas[time_steps]
        has_prev = (time_steps > 0).to(sigma_t.device)
        # Clamp so step-0 samples index a valid entry; torch.where discards it.
        prev_index = (time_steps - 1).clamp(min=0)
        sigma_t_prev = torch.where(has_prev, sigmas[prev_index], torch.zeros_like(sigma_t))
        sigma_sq_diff = sigma_t ** 2 - sigma_t_prev ** 2
        sigma_diff = torch.sqrt(torch.clamp(sigma_sq_diff, min=0))
        drift = -sigma_sq_diff.view(-1, 1, 1, 1) * predicted_noise * dt
        diffusion = sigma_diff.view(-1, 1, 1, 1) * noise if noise is not None else 0
        return torch.clamp(xt + drift + diffusion, -1e5, 1e5)
|