torchdiff-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_ldm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddim/noise_predictor.py ADDED
@@ -0,0 +1,521 @@
import torch
import torch.nn as nn


class NoisePredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        down_channels,
        mid_channels,
        up_channels,
        down_sampling,
        time_embed_dim,
        y_embed_dim,  # output embedding dimension of the text-conditioning network
        num_down_blocks,
        num_mid_blocks,
        num_up_blocks,
        dropout_rate=0.1,
        down_sampling_factor=2,
        where_y=True,
        y_to_all=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        # the decoder mirrors the encoder's down-sampling schedule
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # sinusoidal time embedding and its initial MLP projection
        self.get_time_embedding = GetEmbeddedTime(embed_dim=self.time_embed_dim)
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels) - 1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks; skip connections carry the encoder channel counts in reverse
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i + 1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels) - 1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self):
        """Initialize model weights for better training stability."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, y=None):
        # when the condition is not injected through attention (where_y=False),
        # it is concatenated with the input along the channel dimension instead
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        time_embed = self.get_time_embedding(time_steps=t)
        time_embed = self.time_projection(time_embed)
        skip_connections = []

        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        output = self.conv2(output)
        return output
#-----------------------------------------------------------------------------
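
As a rough guide to how the pieces fit together, here is a minimal, hypothetical configuration; the channel lists are illustrative and not taken from the package docs, and the wiring assumes the usual mirrored doubling/halving layout so that skip and up-sampled channel counts line up:

    model = NoisePredictor(
        in_channels=3,
        down_channels=[32, 64, 128],
        mid_channels=[128, 128],
        up_channels=[128, 64, 32],
        down_sampling=[True, True],
        time_embed_dim=128,
        y_embed_dim=768,
        num_down_blocks=2,
        num_mid_blocks=2,
        num_up_blocks=2,
    )
    x = torch.randn(4, 3, 32, 32)        # batch of noisy images
    t = torch.randint(0, 1000, (4,))     # diffusion time steps
    predicted_noise = model(x, t)        # -> (4, 3, 32, 32)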
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        # 1x1 convolutions for the residual path of each layer
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, embed_time, y):
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            # attention on every layer when y_to_all, otherwise only on the first;
            # Attention falls back to self-attention when y is None
            if self.y_to_all or i == 0:
                output = output + self.attention[i](output, y)
        output = self.down_sampling(output)
        return output
#------------------------------------------------------------------------------
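
For shape intuition: a single DownBlock maps in_channels to out_channels and, when down_sample=True, halves the spatial resolution. A small, hypothetical check (the embedding is a stand-in for the projected time embedding):

    block = DownBlock(in_channels=32, out_channels=64, time_embed_dim=128,
                      y_embed_dim=768, num_layers=2, down_sampling_factor=2,
                      down_sample=True, dropout_rate=0.1, y_to_all=False)
    h = torch.randn(1, 32, 32, 32)
    t_emb = torch.randn(1, 128)
    block(h, t_emb, None).shape          # -> torch.Size([1, 64, 16, 16])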
class MiddleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers + 1)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers + 1)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers + 1)
        ])

    def forward(self, x, embed_time, y=None):
        output = x
        # first residual unit, applied before the attention/residual stack
        resnet_input = output
        output = self.conv1[0](output)
        output = output + self.time_embedding[0](embed_time)[:, :, None, None]
        output = self.conv2[0](output)
        output = output + self.resnet[0](resnet_input)
        for i in range(self.num_layers):
            # attention on every layer when y_to_all, otherwise only on the first;
            # Attention falls back to self-attention when y is None
            if self.y_to_all or i == 0:
                output = output + self.attention[i](output, y)
            resnet_input = output
            output = self.conv1[i + 1](output)
            output = output + self.time_embedding[i + 1](embed_time)[:, :, None, None]
            output = self.conv2[i + 1](output)
            output = output + self.resnet[i + 1](resnet_input)
        return output
#------------------------------------------------------------------------------
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        # channel count expected after up-sampling plus skip concatenation;
        # consistent with the mirrored, doubling channel lists (see the shape check below)
        effective_in_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor,
            conv_block=True,
            up_sampling=True
        ) if up_sampling else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, skip_connection, embed_time, y=None):
        x = self.up_sampling(x)
        x = torch.cat(tensors=[x, skip_connection], dim=1)
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
            # attention on every layer when y_to_all, otherwise only on the first;
            # Attention falls back to self-attention when y is None
            if self.y_to_all or i == 0:
                output = output + self.attention[i](output, y)
        return output
#------------------------------------------------------------------------
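
UpBlock's channel bookkeeping is the subtlest part: UpSampling keeps the channel count at in_channels, while the skip tensor popped in NoisePredictor carries the next-lower encoder width. The two effects cancel, and the concatenation equals effective_in_channels, exactly when up_channels mirrors reversed(down_channels) and consecutive encoder widths double, as in the configuration shown earlier. A hypothetical shape check:

    block = UpBlock(in_channels=128, out_channels=64, skip_channels=128,
                    time_embed_dim=128, y_embed_dim=768, num_layers=2,
                    up_sampling_factor=2)
    h = torch.randn(1, 128, 8, 8)
    skip = torch.randn(1, 64, 16, 16)    # encoder feature popped by NoisePredictor
    t_emb = torch.randn(1, 128)
    block(h, skip, t_emb).shape          # -> torch.Size([1, 64, 16, 16])
    # after up-sampling h still has 128 channels; concatenating the 64-channel
    # skip gives 192 = 128 // 2 + 128 = effective_in_channels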
class Conv3(nn.Module):
    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels) if norm else nn.Identity()
        self.activation = nn.SiLU() if activation else nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        batch = self.conv(batch)
        batch = self.group_norm(batch)
        batch = self.activation(batch)
        batch = self.dropout(batch)
        return batch
#----------------------------------------------------------------
class TimeEmbedding(nn.Module):
    def __init__(self, output_dim, embed_dim):
        super().__init__()
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim)
        )

    def forward(self, batch):
        return self.embedding(batch)
#----------------------------------------------------------------
class GetEmbeddedTime(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time
#----------------------------------------------------------------
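
GetEmbeddedTime computes the transformer-style sinusoidal encoding, but concatenates all sine components followed by all cosine components rather than interleaving them. For example:

    embed = GetEmbeddedTime(embed_dim=8)
    e = embed(torch.tensor([0, 5]))      # shape (2, 8)
    # e[:, :4] are sin(t / 10000 ** (2 * i / 8)) for i = 0..3, e[:, 4:] the
    # matching cosines; t = 0 yields four zeros followed by four ones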
class Attention(nn.Module):
    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        x_reshaped = x.view(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() != 3:
                if y.dim() == 2:
                    y = y.unsqueeze(1)
                else:
                    raise ValueError(
                        f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                    )
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).view(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
#-----------------------------------------------------------------
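
The Attention module performs self-attention over spatial positions when no condition is given, and cross-attention against the projected condition sequence y otherwise. A hypothetical call with CLIP-sized token embeddings (the 77-token length is illustrative):

    attn = Attention(in_channels=64, y_embed_dim=768)
    h = torch.randn(2, 64, 16, 16)
    y = torch.randn(2, 77, 768)          # e.g. a text encoder's token sequence
    attn(h, y).shape                     # -> torch.Size([2, 64, 16, 16])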
class DownSampling(nn.Module):
    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
#--------------------------------------------------------------------------
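
When both branches are enabled, the strided-convolution path and the max-pool path each produce out_channels // 2 feature maps, and their concatenation restores out_channels. So in this hypothetical configuration the module halves the spatial size while doubling the channels:

    down = DownSampling(in_channels=64, out_channels=128, down_sampling_factor=2)
    down(torch.randn(1, 64, 16, 16)).shape   # -> torch.Size([1, 128, 8, 8])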
class UpSampling(nn.Module):
    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor
        half_out_channels = out_channels // 2
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=half_out_channels if up_sampling else out_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if conv_block else nn.Identity()

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=half_out_channels if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if up_sampling else nn.Identity()

    def forward(self, batch):
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
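
UpSampling mirrors DownSampling: a transposed-convolution branch and a nearest-neighbour interpolation branch each emit out_channels // 2 maps (with an interpolation fallback if their spatial sizes disagree), concatenated back to out_channels:

    up = UpSampling(in_channels=128, out_channels=128, up_sampling_factor=2)
    up(torch.randn(1, 128, 8, 8)).shape      # -> torch.Size([1, 128, 16, 16])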
ddim/reverse_ddim.py ADDED
@@ -0,0 +1,91 @@
"""Reverse diffusion process for Denoising Diffusion Implicit Models (DDIM).

This module implements the reverse diffusion process for DDIM, as described in Song et al.
(2021, "Denoising Diffusion Implicit Models"). The reverse process iteratively denoises a
noisy input to reconstruct the original data distribution using a subset of time steps.
"""

import torch
import torch.nn as nn


class ReverseDDIM(nn.Module):
    """Reverse diffusion process of DDIM.

    Implements the reverse diffusion process for Denoising Diffusion Implicit Models
    (DDIM), which denoises a noisy input `xt` using a predicted noise component and a
    subsampled time step schedule, as defined in Song et al. (2021).

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing the noise schedule parameters. Expected to have
        attributes:
        - `tau_num_steps`: Number of subsampled time steps (int).
        - `eta`: Noise scaling factor for the reverse process (float).
        - `get_tau_schedule`: Method to compute the subsampled noise schedule (callable),
          returning a tuple of (betas, alphas, alpha_bars, sqrt_alpha_cumprod,
          sqrt_one_minus_alpha_cumprod).

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    """
    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, xt, predicted_noise, time_steps, prev_time_steps):
        """Applies the reverse diffusion process to the noisy input.

        Denoises the input `xt` at time step `t` to produce the previous step `xt_prev`
        at `prev_time_steps` using the predicted noise and the DDIM reverse process.
        Optionally includes stochastic noise scaled by `eta`.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, hyper_params.tau_num_steps - 1].
        prev_time_steps : torch.Tensor
            Tensor of previous time step indices (long), shape (batch_size,), where each
            value is in the range [0, hyper_params.tau_num_steps - 1].

        Returns
        -------
        tuple
            A tuple containing:
            - xt_prev: Denoised tensor at `prev_time_steps`, same shape as `xt`.
            - x0: Estimated original data (t=0), same shape as `xt`.

        Raises
        ------
        ValueError
            If any value in `time_steps` or `prev_time_steps` is outside the valid range
            [0, hyper_params.tau_num_steps - 1].
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.tau_num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < self.hyper_params.tau_num_steps)):
            raise ValueError(f"prev_time_steps must be between 0 and {self.hyper_params.tau_num_steps - 1}")

        _, _, _, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod = self.hyper_params.get_tau_schedule()
        tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[time_steps].to(xt.device).view(-1, 1, 1, 1)
        tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[time_steps].to(xt.device).view(-1, 1, 1, 1)
        prev_tau_sqrt_alpha_cumprod_t = tau_sqrt_alpha_cumprod[prev_time_steps].to(xt.device).view(-1, 1, 1, 1)
        prev_tau_sqrt_one_minus_alpha_cumprod_t = tau_sqrt_one_minus_alpha_cumprod[prev_time_steps].to(xt.device).view(-1, 1, 1, 1)

        eta = self.hyper_params.eta
        x0 = (xt - tau_sqrt_one_minus_alpha_cumprod_t * predicted_noise) / tau_sqrt_alpha_cumprod_t
        noise_coeff = eta * ((tau_sqrt_one_minus_alpha_cumprod_t / prev_tau_sqrt_alpha_cumprod_t) *
                             prev_tau_sqrt_one_minus_alpha_cumprod_t / torch.clamp(tau_sqrt_one_minus_alpha_cumprod_t, min=1e-8))
        direction_coeff = torch.clamp(prev_tau_sqrt_one_minus_alpha_cumprod_t ** 2 - noise_coeff ** 2, min=1e-8).sqrt()
        xt_prev = prev_tau_sqrt_alpha_cumprod_t * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise

        return xt_prev, x0
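
The update mirrors the DDIM step x_{t-1} = sqrt(alpha_bar_prev) * x0_hat + sqrt(1 - alpha_bar_prev - sigma^2) * eps + sigma * z, with `noise_coeff` in the role of sigma and `direction_coeff` as the square-rooted term. A self-contained sketch with a stub hyper-parameter object; the real class lives in ddim/hyper_param.py and is not shown in this diff, and the linear beta schedule below is only an assumption:

    class ToyHyperParams:
        # hypothetical stand-in for the package's hyper-parameter object
        tau_num_steps = 50
        eta = 0.0                        # eta = 0 gives the deterministic DDIM sampler

        def get_tau_schedule(self):
            betas = torch.linspace(1e-4, 2e-2, self.tau_num_steps)
            alphas = 1.0 - betas
            alpha_bars = torch.cumprod(alphas, dim=0)
            return betas, alphas, alpha_bars, alpha_bars.sqrt(), (1.0 - alpha_bars).sqrt()

    reverse = ReverseDDIM(ToyHyperParams())
    xt = torch.randn(2, 3, 32, 32)
    predicted_noise = torch.randn_like(xt)   # stands in for the noise predictor's output
    t = torch.full((2,), 10)
    t_prev = torch.full((2,), 9)
    xt_prev, x0 = reverse(xt, predicted_noise, t, t_prev)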