TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
sde/noise_predictor.py ADDED
@@ -0,0 +1,521 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class NoisePredictor(nn.Module):
    """U-Net style noise predictor conditioned on a time step and optional context.

    The network is an input convolution, a stack of ``DownBlock`` stages, a
    stack of ``MiddleBlock`` stages, a stack of ``UpBlock`` stages fed with the
    skip connections recorded on the way down, and a final
    GroupNorm/Dropout/Conv head that maps back to ``in_channels`` channels.

    Parameters
    ----------
    in_channels : int
        Channel count of the input; also the channel count of the output.
    down_channels : sequence of int
        Channel widths of the encoder; consecutive pairs define one ``DownBlock``.
    mid_channels : sequence of int
        Channel widths of the bottleneck; consecutive pairs define one ``MiddleBlock``.
    up_channels : sequence of int
        Channel widths of the decoder; consecutive pairs define one ``UpBlock``.
    down_sampling : sequence of bool
        Per-encoder-stage flag enabling spatial down-sampling. The reversed
        sequence is used as the per-decoder-stage up-sampling flags.
    time_embed_dim : int
        Width of the sinusoidal time embedding and of its projection MLP.
    y_embed_dim : int
        Width of the conditioning embedding handed to the attention layers.
    num_down_blocks, num_mid_blocks, num_up_blocks : int
        Number of layers inside each down / middle / up block.
    dropout_rate : float, default 0.1
        Dropout probability used throughout the network.
    down_sampling_factor : int, default 2
        Spatial factor used for both down-sampling and up-sampling.
    where_y : bool, default True
        When False and ``y`` is given, ``y`` is concatenated to ``x`` on the
        channel axis before the first convolution.
    y_to_all : bool, default False
        Forwarded to every block; selects attention on every layer instead of
        only the first layer of a block.
    """

    def __init__(
        self,
        in_channels,
        down_channels,
        mid_channels,
        up_channels,
        down_sampling,
        time_embed_dim,
        y_embed_dim,  # output embedding dimension in text conditional net
        num_down_blocks,
        num_mid_blocks,
        num_up_blocks,
        dropout_rate=0.1,
        down_sampling_factor=2,
        where_y=True,
        y_to_all=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # Parameter-free sinusoidal embedder, built once instead of on every
        # forward pass. It owns no parameters or buffers, so state_dict keys
        # are unchanged.
        self.time_embedder = GetEmbeddedTime(embed_dim=self.time_embed_dim)
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels) - 1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i + 1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels) - 1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self):
        """Re-initialize conv / linear / transposed-conv weights (Kaiming normal) and zero their biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x, t, y=None):
        """Predict the noise for input ``x`` at time steps ``t``.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map fed to a ``Conv2d`` (4-D, channels on axis 1).
        t : torch.Tensor
            Time steps, one per sample; presumably shape ``(batch,)`` -- TODO confirm.
        y : torch.Tensor or None
            Optional conditioning passed to every block's attention layers.

        Returns
        -------
        torch.Tensor
            Output of the final convolution head (``in_channels`` channels).
        """
        # NOTE(review): on this path ``y`` is concatenated to the input but is
        # still forwarded to every block as attention context, and ``conv1`` is
        # sized for ``in_channels`` only -- confirm callers account for both.
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        time_embed = self.time_projection(self.time_embedder(time_steps=t))

        # Record the input of every encoder stage; the decoder consumes them
        # in reverse order (last recorded, first used).
        skip_connections = []
        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        return self.conv2(output)
127
+ #-----------------------------------------------------------------------------
128
class DownBlock(nn.Module):
    """Encoder stage: residual conv layers with attention, then optional down-sampling.

    Each of the ``num_layers`` layers runs ``Conv3 -> + time embedding -> Conv3``
    and adds a 1x1-convolved shortcut of the layer input. Attention is added
    after the first layer only, or after every layer when ``y_to_all`` is True.
    When ``down_sample`` is truthy the result goes through ``DownSampling``;
    otherwise it is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # Only the first layer sees the block's input width.
        entry_widths = [in_channels if idx == 0 else out_channels for idx in range(self.num_layers)]

        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=width,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for width in entry_widths
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in entry_widths
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        if down_sample:
            self.down_sampling = DownSampling(
                in_channels=out_channels,
                out_channels=out_channels,
                down_sampling_factor=down_sampling_factor,
                conv_block=True,
                max_pool=True
            )
        else:
            self.down_sampling = nn.Identity()
        # 1x1 shortcuts so the residual sum has matching channel counts.
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in entry_widths
        ])

    def forward(self, x, embed_time, y):
        """Run the residual layers on ``x`` and return the (optionally down-sampled) result."""
        hidden = x
        for idx in range(self.num_layers):
            shortcut = hidden
            hidden = self.conv1[idx](hidden)
            # Broadcast the per-sample time vector over the spatial axes.
            hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
            hidden = self.conv2[idx](hidden)
            hidden = hidden + self.resnet[idx](shortcut)
            # Attention runs on the first layer always, on later layers only
            # with y_to_all; y=None makes the layer attend to itself.
            if self.y_to_all or idx == 0:
                hidden = hidden + self.attention[idx](hidden, y)

        return self.down_sampling(hidden)
211
+ #------------------------------------------------------------------------------
212
class MiddleBlock(nn.Module):
    """Bottleneck stage: ``num_layers + 1`` residual stages with attention in between.

    The first residual stage runs unconditionally. Each loop iteration then
    optionally adds attention (first iteration only, or every iteration when
    ``y_to_all`` is True) and runs the next residual stage. The attention list
    is built with ``num_layers + 1`` entries; ``forward`` uses indices
    ``0 .. num_layers - 1``.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        stage_count = self.num_layers + 1
        # Only the first stage sees the block's input width.
        entry_widths = [in_channels if idx == 0 else out_channels for idx in range(stage_count)]

        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=width,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for width in entry_widths
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in entry_widths
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in entry_widths
        ])

    def _residual_stage(self, idx, tensor, embed_time):
        """Apply residual stage ``idx``: conv, add time embedding, conv, add 1x1 shortcut."""
        shortcut = tensor
        tensor = self.conv1[idx](tensor)
        tensor = tensor + self.time_embedding[idx](embed_time)[:, :, None, None]
        tensor = self.conv2[idx](tensor)
        return tensor + self.resnet[idx](shortcut)

    def forward(self, x, embed_time, y=None):
        """Run the bottleneck stages on ``x`` and return the result."""
        output = self._residual_stage(0, x, embed_time)
        for idx in range(self.num_layers):
            # Attention on the first iteration always, later ones only with
            # y_to_all; y=None makes the layer attend to itself.
            if self.y_to_all or idx == 0:
                output = output + self.attention[idx](output, y)
            output = self._residual_stage(idx + 1, output, embed_time)

        return output
291
+ #------------------------------------------------------------------------------
292
class UpBlock(nn.Module):
    """Decoder stage: up-sample, concatenate the skip connection, then residual layers.

    The input is passed through ``UpSampling`` (or left unchanged when
    ``up_sampling`` is falsy), concatenated with ``skip_connection`` on the
    channel axis, and processed by ``num_layers`` residual layers. Attention
    follows the first layer only, or every layer when ``y_to_all`` is True.
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers, up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # Width the first layer is built for after the concat in forward().
        # NOTE(review): this must equal the channels actually produced by
        # up-sampling plus the skip tensor -- confirm for your channel config.
        effective_in_channels = in_channels // 2 + skip_channels
        entry_widths = [
            effective_in_channels if idx == 0 else out_channels
            for idx in range(self.num_layers)
        ]

        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=width,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for width in entry_widths
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in entry_widths
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in entry_widths
        ])
        if up_sampling:
            self.up_sampling = UpSampling(
                in_channels=in_channels,
                out_channels=in_channels,
                up_sampling_factor=up_sampling_factor,
                conv_block=True,
                up_sampling=True
            )
        else:
            self.up_sampling = nn.Identity()
        # 1x1 shortcuts so the residual sum has matching channel counts.
        self.resnet = nn.ModuleList([
            nn.Conv2d(in_channels=width, out_channels=out_channels, kernel_size=1)
            for width in entry_widths
        ])

    def forward(self, x, skip_connection, embed_time, y=None):
        """Up-sample ``x``, merge ``skip_connection`` and run the residual layers."""
        merged = torch.cat(tensors=[self.up_sampling(x), skip_connection], dim=1)
        hidden = merged
        for idx in range(self.num_layers):
            shortcut = hidden
            hidden = self.conv1[idx](hidden)
            # Broadcast the per-sample time vector over the spatial axes.
            hidden = hidden + self.time_embedding[idx](embed_time)[:, :, None, None]
            hidden = self.conv2[idx](hidden)
            hidden = hidden + self.resnet[idx](shortcut)
            # Attention on the first layer always, later ones only with
            # y_to_all; y=None makes the layer attend to itself.
            if self.y_to_all or idx == 0:
                hidden = hidden + self.attention[idx](hidden, y)
        return hidden
376
+ #------------------------------------------------------------------------
377
class Conv3(nn.Module):
    """Convolution followed by optional GroupNorm, optional SiLU, and dropout.

    Padding is ``(kernel_size - 1) // 2``. ``norm=False`` or
    ``activation=False`` replaces the corresponding step with ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)
        if norm:
            self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        else:
            self.group_norm = nn.Identity()
        if activation:
            self.activation = nn.SiLU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        """Apply conv -> norm -> activation -> dropout to ``batch``."""
        features = self.group_norm(self.conv(batch))
        return self.dropout(self.activation(features))
391
+ #----------------------------------------------------------------
392
class TimeEmbedding(nn.Module):
    """Project a time embedding of width ``embed_dim`` to ``output_dim`` via SiLU then Linear."""

    def __init__(self, output_dim, embed_dim):
        super().__init__()
        layers = [
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim),
        ]
        self.embedding = nn.Sequential(*layers)

    def forward(self, batch):
        """Return the projected embedding for ``batch``."""
        projected = self.embedding(batch)
        return projected
401
+ #----------------------------------------------------------------
402
class GetEmbeddedTime(nn.Module):
    """Sinusoidal embedding of time steps.

    For each time step the output holds ``embed_dim // 2`` sine values
    followed by ``embed_dim // 2`` cosine values, using frequencies
    ``1 / 10000 ** (2 * i / embed_dim)``. The module has no parameters.
    """

    def __init__(self, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        """Return the embedding with a trailing axis of size ``embed_dim``."""
        half_dim = self.embed_dim // 2
        positions = torch.arange(
            start=0,
            end=half_dim,
            dtype=torch.float32,
            device=time_steps.device,
        )
        scale = 10000 ** (2 * positions / self.embed_dim)
        # One row per time step, one column per frequency.
        angles = time_steps.unsqueeze(1) / scale
        return torch.cat(tensors=[torch.sin(angles), torch.cos(angles)], dim=-1)
414
+ #----------------------------------------------------------------
415
class Attention(nn.Module):
    """Multi-head attention over the spatial positions of a feature map.

    The feature map is flattened to a sequence of ``h * w`` tokens of width
    ``in_channels``. With ``y`` given, the tokens attend to the linearly
    projected ``y`` (cross-attention); without it they attend to themselves.
    The result is reshaped back, group-normalized and passed through dropout.

    Parameters
    ----------
    in_channels : int
        Channel count of the feature map; also the attention embedding width.
    y_embed_dim : int, default 768
        Width of the last axis of ``y`` before projection.
    num_heads : int, default 4
        Number of attention heads.
    num_groups : int, default 8
        Number of groups for the output ``GroupNorm``.
    dropout_rate : float, default 0.1
        Dropout used inside the attention layer and on the output.
    """

    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        """Return the attention output for ``x``, with the same shape as ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Feature map of shape ``(batch, in_channels, h, w)``.
        y : torch.Tensor or None
            Conditioning whose last axis has size ``y_embed_dim``. A 2-D result
            after projection is treated as a single context token per sample.

        Raises
        ------
        ValueError
            If the projected ``y`` is neither 2-D nor 3-D.
        """
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        # reshape (not view) so non-contiguous feature maps, e.g. transposed
        # tensors, are accepted instead of raising a RuntimeError.
        x_reshaped = x.reshape(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            # The projection always yields a last axis of size in_channels,
            # so only the rank needs checking afterwards.
            y = self.y_projection(y)
            if y.dim() == 2:
                y = y.unsqueeze(1)
            elif y.dim() != 3:
                raise ValueError(
                    f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).reshape(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
451
+ #-----------------------------------------------------------------
452
class DownSampling(nn.Module):
    """Spatial down-sampling through a strided-conv path, a max-pool path, or both.

    With both paths enabled each produces ``out_channels // 2`` channels and
    the two results are concatenated on the channel axis. With only one path
    enabled that path produces ``out_channels`` channels on its own. With
    neither enabled the input is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor

        # Each path yields half the channels only when the other path exists.
        conv_width = out_channels // 2 if max_pool else out_channels
        pool_width = out_channels // 2 if conv_block else out_channels

        if conv_block:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
                nn.Conv2d(in_channels=in_channels, out_channels=conv_width,
                          kernel_size=3, stride=down_sampling_factor, padding=1)
            )
        else:
            self.conv = nn.Identity()

        if max_pool:
            self.pool = nn.Sequential(
                nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
                nn.Conv2d(in_channels=in_channels, out_channels=pool_width,
                          kernel_size=1, stride=1, padding=0)
            )
        else:
            self.pool = nn.Identity()

    def forward(self, batch):
        """Down-sample ``batch`` with the enabled path(s)."""
        if self.conv_block and self.max_pool:
            branches = [self.conv(batch), self.pool(batch)]
            return torch.cat(tensors=branches, dim=1)
        if self.conv_block:
            return self.conv(batch)
        return self.pool(batch)
475
+ #--------------------------------------------------------------------------
476
class UpSampling(nn.Module):
    """Spatial up-sampling through a transposed-conv path, a nearest-neighbour path, or both.

    With both paths enabled each produces ``out_channels // 2`` channels and
    the two results are concatenated on the channel axis. If their spatial
    sizes differ, the nearest-neighbour result is resized to the
    transposed-conv result first. With one path enabled, that path alone
    produces ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor

        half_out_channels = out_channels // 2
        # Each path yields half the channels only when the other path exists.
        conv_width = half_out_channels if up_sampling else out_channels
        interp_width = half_out_channels if conv_block else out_channels

        if conv_block:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=conv_width,
                    kernel_size=3,
                    stride=up_sampling_factor,
                    padding=1,
                    output_padding=up_sampling_factor - 1
                ),
                nn.Conv2d(
                    in_channels=conv_width,
                    out_channels=conv_width,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            )
        else:
            self.conv = nn.Identity()

        if up_sampling:
            self.up_sample = nn.Sequential(
                nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
                nn.Conv2d(in_channels=in_channels, out_channels=interp_width,
                          kernel_size=1, stride=1, padding=0)
            )
        else:
            self.up_sample = nn.Identity()

    def forward(self, batch):
        """Up-sample ``batch`` with the enabled path(s)."""
        if not self.conv_block:
            return self.up_sample(batch)
        learned = self.conv(batch)
        if not self.up_sampling:
            return learned
        interpolated = self.up_sample(batch)
        target_size = learned.shape[2:]
        if target_size != interpolated.shape[2:]:
            # Align the interpolated path to the transposed-conv output size.
            _, _, h, w = learned.shape
            interpolated = torch.nn.functional.interpolate(
                interpolated,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[learned, interpolated], dim=1)
sde/reverse_sde.py ADDED
@@ -0,0 +1,115 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
+
7
class ReverseSDE(nn.Module):
    """Reverse diffusion process for SDE-based generative models.

    Implements the reverse diffusion process for score-based generative models using
    Stochastic Differential Equations (SDEs), supporting Variance Exploding (VE),
    Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE methods, as
    described in Song et al. (2021). The reverse process denoises a noisy input using
    predicted noise estimates.

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing SDE-specific parameters. Only the
        attributes needed by the selected method are read:
        - `dt`: Time step size for SDE integration (all methods).
        - `sigmas`: Sigma values, indexed by time step ("ve" only).
        - `betas`: Beta values, indexed by time step ("vp", "sub-vp", "ode").
        - `cum_betas`: Cumulative beta values, indexed by time step ("sub-vp" only).
    method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    method : str
        Selected SDE method. It is not validated at construction; an
        unsupported value raises ``ValueError`` when ``forward`` is called.
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    @staticmethod
    def _per_sample(values):
        """Reshape per-sample scalars so they broadcast over (C, H, W)."""
        return values.view(-1, 1, 1, 1)

    def forward(self, xt, noise, predicted_noise, time_steps):
        """Applies one reverse step of the selected method to the noisy input.

        Parameters
        ----------
        xt : torch.Tensor
            Noisy input tensor at time step `t`, shape (batch_size, channels, height, width).
        noise : torch.Tensor or None
            Gaussian noise tensor, same shape as `xt`, used for stochasticity. If None,
            no stochastic noise is added. Ignored by the "ode" method.
        predicted_noise : torch.Tensor
            Predicted noise tensor, same shape as `xt`, typically output by a neural network.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), valid as
            indices into the schedule tensors of `hyper_params`.

        Returns
        -------
        torch.Tensor
            Updated tensor after one reverse step, same shape as `xt`.

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp", "ode").

        Notes
        -----
        - For the "ve" and "ode" methods, the output is clamped to [-1e5, 1e5].
        - For "ve", the previous sigma is taken per sample: samples at step 0 use
          zero, all other samples use `sigmas[t - 1]`.
        """
        dt = self.hyper_params.dt

        if self.method == "ve":
            sigmas = self.hyper_params.sigmas
            sigma_t = sigmas[time_steps]
            # Resolve the previous sigma per sample, so a single step-0 sample
            # does not zero it for the whole batch. The clamp keeps the index
            # valid at step 0, where the value is replaced by zero anyway.
            prev_steps = torch.clamp(time_steps - 1, min=0)
            sigma_t_prev = torch.where(
                time_steps > 0, sigmas[prev_steps], torch.zeros_like(sigma_t)
            )
            variance_gap = sigma_t ** 2 - sigma_t_prev ** 2
            sigma_diff = torch.sqrt(torch.clamp(variance_gap, min=0))
            drift = -self._per_sample(variance_gap) * predicted_noise * dt
            diffusion = self._per_sample(sigma_diff) * noise if noise is not None else 0
            xt = xt + drift + diffusion
            return torch.clamp(xt, -1e5, 1e5)

        if self.method == "vp":
            betas = self._per_sample(self.hyper_params.betas[time_steps])
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            return xt + drift + diffusion

        if self.method == "sub-vp":
            betas = self._per_sample(self.hyper_params.betas[time_steps])
            cum_betas = self._per_sample(self.hyper_params.cum_betas[time_steps])
            # Extra sub-VP factor scaling both the score term and the noise.
            damping = 1 - torch.exp(-2 * cum_betas)
            drift = -0.5 * betas * xt * dt - betas * damping * predicted_noise * dt
            diffusion = torch.sqrt(betas * damping * dt) * noise if noise is not None else 0
            return xt + drift + diffusion

        if self.method == "ode":
            # Deterministic beta-based update; no noise term is added.
            betas = self._per_sample(self.hyper_params.betas[time_steps])
            drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
            return torch.clamp(xt + drift, -1e5, 1e5)

        raise ValueError(f"Unknown method: {self.method}")