TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_ldm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/autoencoder.py ADDED
@@ -0,0 +1,855 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+
+
+ class AutoencoderLDM(nn.Module):
+     """Variational autoencoder for latent space compression in Latent Diffusion Models.
+
+     Encodes images into a latent space and decodes them back to the image space, used as
+     the `compressor_model` in LDM’s `TrainLDM` and `SampleLDM`. Supports KL-divergence
+     or vector quantization (VQ) regularization for the latent representation.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels (e.g., 3 for RGB images).
+     down_channels : list
+         List of channel sizes for encoder downsampling blocks (e.g., [32, 64, 128, 256]).
+     up_channels : list
+         List of channel sizes for decoder upsampling blocks (e.g., [256, 128, 64, 16]).
+     out_channels : int
+         Number of output channels, typically equal to `in_channels`.
+     dropout_rate : float
+         Dropout rate for regularization in convolutional and attention layers.
+     num_heads : int
+         Number of attention heads in self-attention layers.
+     num_groups : int
+         Number of groups for group normalization in attention layers.
+     num_layers_per_block : int
+         Number of convolutional layers in each downsampling and upsampling block.
+     total_down_sampling_factor : int
+         Total downsampling factor across the encoder (e.g., 8 for 8x reduction).
+     latent_channels : int
+         Number of channels in the latent representation for diffusion models.
+     num_embeddings : int
+         Number of discrete embeddings in the VQ codebook (if `use_vq=True`).
+     use_vq : bool, optional
+         If True, uses vector quantization (VQ) regularization; otherwise, uses
+         KL-divergence (default: False).
+     beta : float, optional
+         Weight for KL-divergence loss (if `use_vq=False`) (default: 1.0).
+
+     Attributes
+     ----------
+     use_vq : bool
+         Whether VQ regularization is used.
+     beta : float
+         Fixed weight for KL-divergence loss.
+     current_beta : float
+         Current weight for KL-divergence loss (modifiable during training).
+     down_sampling_factor : int
+         Downsampling factor per block, derived from `total_down_sampling_factor`.
+     conv1 : torch.nn.Conv2d
+         Initial convolutional layer for encoding.
+     down_blocks : torch.nn.ModuleList
+         List of DownBlock modules for encoder downsampling.
+     attention1 : Attention
+         Self-attention layer after encoder downsampling.
+     vq_layer : VectorQuantizer or None
+         Vector quantization layer (if `use_vq=True`).
+     conv_mu : torch.nn.Conv2d or None
+         Convolutional layer for mean of latent distribution (if `use_vq=False`).
+     conv_logvar : torch.nn.Conv2d or None
+         Convolutional layer for log-variance of latent distribution (if `use_vq=False`).
+     quant_conv : torch.nn.Conv2d
+         Convolutional layer to project latent representation to `latent_channels`.
+     conv2 : torch.nn.Conv2d
+         Initial convolutional layer for decoding.
+     attention2 : Attention
+         Self-attention layer after decoder’s initial convolution.
+     up_blocks : torch.nn.ModuleList
+         List of UpBlock modules for decoder upsampling.
+     conv3 : Conv3
+         Final convolutional layer for output reconstruction.
+
+     Raises
+     ------
+     AssertionError
+         If `in_channels` does not equal `out_channels`.
+
+     Notes
+     -----
+     - The encoder downsamples images using `DownBlock` modules, followed by self-attention
+       and latent projection (VQ or KL-based).
+     - The decoder upsamples the latent representation using `UpBlock` modules, with
+       self-attention and final convolution.
+     - The `down_sampling_factor` is computed as `total_down_sampling_factor` raised to
+       the power of `1 / (len(down_channels) - 1)`, applied per downsampling block.
+     - The latent representation has `latent_channels` channels, suitable for LDM’s
+       diffusion process.
+     """
+     def __init__(
+         self,
+         in_channels,  # number of channels of the original image, e.g., 3 for RGB.
+         down_channels,  # a list of channels used in the encoder, e.g., [32, 64, 128, 256].
+         up_channels,  # a list of channels used in the decoder, e.g., [256, 128, 64, 16].
+         out_channels,  # usually the same as in_channels; used to reconstruct the image.
+         dropout_rate,  # dropout rate, prevents overfitting.
+         num_heads,  # number of attention heads in self-attention layers.
+         num_groups,  # number of groups in group normalization, used in self-attention.
+         num_layers_per_block,  # number of convolutional layers within each down/up block.
+         total_down_sampling_factor,  # total down-sampling factor across the encoder; the per-block factor is derived from it.
+         latent_channels,  # number of channels of the latent z passed to the diffusion model.
+         num_embeddings,  # number of discrete embeddings in the VQ codebook (used when use_vq=True).
+         use_vq=False,  # if True, use VQ regularization; otherwise use KL regularization.
+         beta=1.0  # weight for the KL loss.
+
+     ):
+         super().__init__()
+         assert in_channels == out_channels, "Input and output channels must match for auto-encoding"
+         self.use_vq = use_vq
+         self.beta = beta
+         self.current_beta = beta
+         num_down_blocks = len(down_channels) - 1
+         self.down_sampling_factor = int(total_down_sampling_factor ** (1 / num_down_blocks))
+
+         # Encoder
+         self.conv1 = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, padding=1)
+         self.down_blocks = nn.ModuleList([
+             DownBlock(
+                 in_channels=down_channels[i],
+                 out_channels=down_channels[i + 1],
+                 num_layers=num_layers_per_block,
+                 down_sampling_factor=self.down_sampling_factor,
+                 dropout_rate=dropout_rate
+             ) for i in range(num_down_blocks)
+         ])
+         self.attention1 = Attention(down_channels[-1], num_heads, num_groups, dropout_rate)
+
+         # Latent projection
+         if use_vq:
+             self.vq_layer = VectorQuantizer(num_embeddings, down_channels[-1])
+             self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)
+         else:
+             self.conv_mu = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
+             self.conv_logvar = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
+             self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)
+
+         # Decoder
+         self.conv2 = nn.Conv2d(latent_channels, up_channels[0], kernel_size=3, padding=1)
+         self.attention2 = Attention(up_channels[0], num_heads, num_groups, dropout_rate)
+         self.up_blocks = nn.ModuleList([
+             UpBlock(
+                 in_channels=up_channels[i],
+                 out_channels=up_channels[i + 1],
+                 num_layers=num_layers_per_block,
+                 up_sampling_factor=self.down_sampling_factor,
+                 dropout_rate=dropout_rate
+             ) for i in range(len(up_channels) - 1)
+         ])
+         self.conv3 = Conv3(up_channels[-1], out_channels, dropout_rate)
+
+     def reparameterize(self, mu, logvar):
+         """Applies reparameterization trick for variational autoencoding.
+
+         Samples from a Gaussian distribution using the mean and log-variance to enable
+         differentiable training.
+
+         Parameters
+         ----------
+         mu : torch.Tensor
+             Mean of the latent distribution, shape (batch_size, channels, height, width).
+         logvar : torch.Tensor
+             Log-variance of the latent distribution, same shape as `mu`.
+
+         Returns
+         -------
+         torch.Tensor
+             Sampled latent representation, same shape as `mu`.
+         """
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+
+     def encode(self, x):
+         """Encodes images into a latent representation.
+
+         Processes input images through the encoder, applying convolutions, downsampling,
+         self-attention, and latent projection (VQ or KL-based).
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input images, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         tuple
+             A tuple containing:
+             - z: Latent representation, shape (batch_size, latent_channels,
+               height/down_sampling_factor, width/down_sampling_factor).
+             - reg_loss: Regularization loss (VQ loss if `use_vq=True`, KL-divergence
+               loss if `use_vq=False`).
+
+         Notes
+         -----
+         - The VQ loss is computed by `VectorQuantizer` if `use_vq=True`.
+         - The KL-divergence loss is normalized by batch size and latent size, weighted
+           by `current_beta`.
+         """
+         x = self.conv1(x)
+         for block in self.down_blocks:
+             x = block(x)
+         res_x = x
+         x = self.attention1(x)
+         x = x + res_x
+         if self.use_vq:
+             z, vq_loss = self.vq_layer(x)
+             z = self.quant_conv(z)
+             return z, vq_loss
+         else:
+             mu = self.conv_mu(x)
+             logvar = self.conv_logvar(x)
+             z = self.reparameterize(mu, logvar)
+             z = self.quant_conv(z)
+             kl_unnormalized = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+             batch_size = x.size(0)
+             latent_size = torch.prod(torch.tensor(mu.shape[1:])).item()
+             kl_loss = kl_unnormalized / (batch_size * latent_size) * self.current_beta
+             return z, kl_loss
+
+     def decode(self, z):
+         """Decodes latent representations back to images.
+
+         Processes latent representations through the decoder, applying convolutions,
+         self-attention, upsampling, and final reconstruction.
+
+         Parameters
+         ----------
+         z : torch.Tensor
+             Latent representation, shape (batch_size, latent_channels,
+             height/down_sampling_factor, width/down_sampling_factor).
+
+         Returns
+         -------
+         torch.Tensor
+             Reconstructed images, shape (batch_size, out_channels, height, width).
+         """
+         x = self.conv2(z)
+         res_x = x
+         x = self.attention2(x)
+         x = x + res_x
+         for block in self.up_blocks:
+             x = block(x)
+         x = self.conv3(x)
+         return x
+
+     def forward(self, x):
+         """Encodes images to latent space and decodes them, computing reconstruction and regularization losses.
+
+         Performs a full autoencoding pass, encoding images to the latent space, decoding
+         them back, and calculating MSE reconstruction loss and regularization loss (VQ
+         or KL-based).
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input images, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         tuple
+             A tuple containing:
+             - x_hat: Reconstructed images, shape (batch_size, out_channels, height,
+               width).
+             - total_loss: Sum of reconstruction (MSE) and regularization losses.
+             - reg_loss: Regularization loss (VQ or KL-divergence).
+             - z: Latent representation, shape (batch_size, latent_channels,
+               height/down_sampling_factor, width/down_sampling_factor).
+
+         Notes
+         -----
+         - The reconstruction loss is computed as the mean squared error between `x_hat`
+           and `x`.
+         - The regularization loss depends on `use_vq` (VQ loss or KL-divergence).
+         """
+         z, reg_loss = self.encode(x)
+         x_hat = self.decode(z)
+         recon_loss = F.mse_loss(x_hat, x)
+         total_loss = recon_loss + reg_loss
+         return x_hat, total_loss, reg_loss, z  # return z for DM
+ #------------------------------------------------------------------------------------------------
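
For orientation, a minimal usage sketch of AutoencoderLDM as defined above; the import path, channel lists, and hyperparameter values are illustrative assumptions rather than values shipped with the package.

    import torch
    from ldm.autoencoder import AutoencoderLDM  # assumed import path

    # Small KL-regularized configuration; all sizes below are illustrative.
    vae = AutoencoderLDM(
        in_channels=3,
        down_channels=[64, 128, 256],    # 2 downsampling blocks
        up_channels=[256, 128, 64],      # 2 upsampling blocks
        out_channels=3,
        dropout_rate=0.1,
        num_heads=4,
        num_groups=8,
        num_layers_per_block=2,
        total_down_sampling_factor=4,    # per-block factor = int(4 ** (1 / 2)) = 2
        latent_channels=4,
        num_embeddings=512,              # unused when use_vq=False
        use_vq=False,
        beta=1.0,
    )

    x = torch.randn(2, 3, 64, 64)
    x_hat, total_loss, reg_loss, z = vae(x)
    print(x_hat.shape)     # torch.Size([2, 3, 64, 64])
    print(z.shape)         # torch.Size([2, 4, 16, 16]), the latent handed to the diffusion model
    total_loss.backward()  # MSE reconstruction + current_beta-weighted KL, as computed in forward()
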
+ class VectorQuantizer(nn.Module):
+     """Vector quantization layer for discretizing latent representations.
+
+     Quantizes input latent vectors to the nearest embedding in a learned codebook,
+     used in `AutoencoderLDM` when `use_vq=True` to enable discrete latent spaces for
+     Latent Diffusion Models. Computes commitment and codebook losses to train the
+     codebook embeddings.
+
+     Parameters
+     ----------
+     num_embeddings : int
+         Number of discrete embeddings in the codebook.
+     embedding_dim : int
+         Dimensionality of each embedding vector (matches input channel dimension).
+     commitment_cost : float, optional
+         Weight for the commitment loss, encouraging inputs to be close to quantized
+         values (default: 0.25).
+
+     Attributes
+     ----------
+     embedding_dim : int
+         Dimensionality of embedding vectors.
+     num_embeddings : int
+         Number of embeddings in the codebook.
+     commitment_cost : float
+         Weight for commitment loss.
+     embedding : torch.nn.Embedding
+         Embedding layer containing the codebook, shape (num_embeddings,
+         embedding_dim).
+
+     Notes
+     -----
+     - The codebook embeddings are initialized uniformly in the range
+       [-1/num_embeddings, 1/num_embeddings].
+     - The forward pass flattens input latents, computes Euclidean distances to
+       codebook embeddings, and selects the nearest embedding for quantization.
+     - The commitment loss encourages input latents to be close to their quantized
+       versions, while the codebook loss updates embeddings to match inputs.
+     - A straight-through estimator is used to pass gradients from the quantized output
+       to the input.
+     """
+     def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
+         super().__init__()
+         # dimensionality of each embedding vector
+         self.embedding_dim = embedding_dim
+         # number of discrete embeddings in the codebook
+         self.num_embeddings = num_embeddings
+         # commitment cost for the loss term to encourage z to be close to quantized values
+         self.commitment_cost = commitment_cost
+         self.embedding = nn.Embedding(num_embeddings, embedding_dim)
+         # initialize embedding weights uniformly
+         self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
+
+     def forward(self, z):
+         """Quantizes latent representations to the nearest codebook embedding.
+
+         Computes the closest embedding for each input vector, applies quantization,
+         and calculates commitment and codebook losses for training.
+
+         Parameters
+         ----------
+         z : torch.Tensor
+             Input latent representation, shape (batch_size, embedding_dim, height,
+             width).
+
+         Returns
+         -------
+         tuple
+             A tuple containing:
+             - quantized: Quantized latent representation, same shape as `z`.
+             - vq_loss: Sum of commitment and codebook losses.
+
+         Raises
+         ------
+         AssertionError
+             If the channel dimension of `z` does not match `embedding_dim`.
+
+         Notes
+         -----
+         - The input is flattened to (batch_size * height * width, embedding_dim) for
+           distance computation.
+         - Euclidean distances are computed efficiently using vectorized operations.
+         - The commitment loss is scaled by `commitment_cost`, and the total VQ loss
+           combines commitment and codebook losses.
+         """
+         z = z.contiguous()  # ensure contiguity in memory
+         # flatten z to (batch_size * height * width, embedding_dim) for distance computation
+         assert z.size(1) == self.embedding_dim, f"Expected channel dim {self.embedding_dim}, got {z.size(1)}"
+         z_flattened = z.reshape(-1, self.embedding_dim)
+         # compute squared euclidean distances between z_flattened and all embeddings
+         distances = (torch.sum(z_flattened ** 2, dim=1, keepdim=True)
+                      + torch.sum(self.embedding.weight ** 2, dim=1)
+                      - 2 * torch.matmul(z_flattened, self.embedding.weight.t()))
+         # find the index of the closest embedding for each z_flattened vector
+         encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+         # convert indices to one-hot encodings
+         encodings = F.one_hot(encoding_indices, self.num_embeddings).float().squeeze(1)
+         # map one-hot encodings to quantized values using the embedding weights
+         quantized = torch.matmul(encodings, self.embedding.weight).view_as(z)
+         # commitment loss to encourage z to be close to its quantized version
+         commitment_loss = self.commitment_cost * torch.mean((z.detach() - quantized) ** 2)
+         # codebook loss to encourage embeddings to move closer to z
+         codebook_loss = torch.mean((z - quantized.detach()) ** 2)
+         # straight-through estimator: copy gradients from quantized to z
+         quantized = z + (quantized - z).detach()
+         # return the quantized tensor and the combined vq loss
+         return quantized, commitment_loss + codebook_loss
+ #------------------------------------------------------------------------------------------------
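
A small sketch of the VectorQuantizer above used in isolation, showing the quantized output shape and the straight-through gradient path; the import path and sizes are assumptions for illustration.

    import torch
    from ldm.autoencoder import VectorQuantizer  # assumed import path

    vq = VectorQuantizer(num_embeddings=512, embedding_dim=256, commitment_cost=0.25)
    z = torch.randn(2, 256, 16, 16, requires_grad=True)

    quantized, vq_loss = vq(z)
    print(quantized.shape)  # torch.Size([2, 256, 16, 16]), same shape as the input
    print(vq_loss.item())   # scalar: commitment loss + codebook loss

    # Straight-through estimator: gradients pass through the quantization as if it
    # were the identity, so d(quantized)/d(z) is 1 everywhere.
    quantized.sum().backward()
    print(torch.allclose(z.grad, torch.ones_like(z)))  # True
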
+ class DownBlock(nn.Module):
+     """Downsampling block for the encoder in AutoencoderLDM.
+
+     Applies multiple convolutional layers with residual connections followed by
+     downsampling to reduce spatial dimensions in the encoder of the variational
+     autoencoder used in Latent Diffusion Models.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels for convolutional layers.
+     num_layers : int
+         Number of convolutional layer pairs (Conv3) per block.
+     down_sampling_factor : int
+         Factor by which to downsample spatial dimensions.
+     dropout_rate : float
+         Dropout rate for Conv3 layers.
+
+     Attributes
+     ----------
+     num_layers : int
+         Number of convolutional layer pairs.
+     conv1 : torch.nn.ModuleList
+         List of Conv3 layers for the first convolution in each pair.
+     conv2 : torch.nn.ModuleList
+         List of Conv3 layers for the second convolution in each pair.
+     down_sampling : DownSampling
+         Downsampling module to reduce spatial dimensions.
+     resnet : torch.nn.ModuleList
+         List of 1x1 convolutional layers for residual connections.
+
+     Notes
+     -----
+     - Each layer pair consists of two Conv3 modules with a residual connection using a
+       1x1 convolution to match dimensions.
+     - The downsampling is applied after all convolutional layers, reducing spatial
+       dimensions by `down_sampling_factor`.
+     """
+     def __init__(self, in_channels, out_channels, num_layers, down_sampling_factor, dropout_rate):
+         super().__init__()
+         self.num_layers = num_layers
+         self.conv1 = nn.ModuleList([
+             Conv3(
+                 in_channels=in_channels if i == 0 else out_channels,
+                 out_channels=out_channels,
+                 dropout_rate=dropout_rate
+             ) for i in range(self.num_layers)
+         ])
+         self.conv2 = nn.ModuleList([
+             Conv3(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 dropout_rate=dropout_rate
+             ) for _ in range(self.num_layers)
+         ])
+
+         self.down_sampling = DownSampling(
+             in_channels=out_channels,
+             out_channels=out_channels,
+             down_sampling_factor=down_sampling_factor
+         )
+         self.resnet = nn.ModuleList([
+             nn.Conv2d(
+                 in_channels=in_channels if i == 0 else out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1
+             ) for i in range(num_layers)
+         ])
+
+     def forward(self, x):
+         """Processes input through convolutional layers and downsampling.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor, shape (batch_size, out_channels,
+             height/down_sampling_factor, width/down_sampling_factor).
+         """
+         output = x
+         for i in range(self.num_layers):
+             resnet_input = output
+             output = self.conv1[i](output)
+             output = self.conv2[i](output)
+             output = output + self.resnet[i](resnet_input)
+         output = self.down_sampling(output)
+         return output
+ # ------------------------------------------------------------------------------------------------
+ class Conv3(nn.Module):
+     """Convolutional layer with group normalization, SiLU activation, and dropout.
+
+     Used in DownBlock and UpBlock of AutoencoderLDM for feature extraction and
+     transformation in the encoder and decoder.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels.
+     dropout_rate : float
+         Dropout rate for regularization.
+
+     Attributes
+     ----------
+     group_norm : torch.nn.GroupNorm
+         Group normalization with 8 groups.
+     activation : torch.nn.SiLU
+         SiLU (Swish) activation function.
+     conv : torch.nn.Conv2d
+         3x3 convolutional layer with padding to maintain spatial dimensions.
+     dropout : torch.nn.Dropout
+         Dropout layer for regularization.
+
+     Notes
+     -----
+     - The layer applies group normalization, SiLU activation, dropout, and a 3x3
+       convolution in sequence.
+     - Spatial dimensions are preserved due to padding=1 in the convolution.
+     """
+     def __init__(self, in_channels, out_channels, dropout_rate):
+         super().__init__()
+         self.group_norm = nn.GroupNorm(num_groups=8, num_channels=in_channels)
+         self.activation = nn.SiLU()
+         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward(self, x):
+         """Processes input through group normalization, activation, dropout, and convolution.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor, shape (batch_size, out_channels, height, width).
+         """
+         x = self.group_norm(x)
+         x = self.activation(x)
+         x = self.dropout(x)
+         x = self.conv(x)
+         return x
+ #------------------------------------------------------------------------------------------------
+ class DownSampling(nn.Module):
+     """Downsampling module for reducing spatial dimensions in AutoencoderLDM’s encoder.
+
+     Combines convolutional downsampling and max pooling, concatenating their outputs
+     to preserve feature information during downsampling in DownBlock.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels (sum of conv and pool paths).
+     down_sampling_factor : int
+         Factor by which to downsample spatial dimensions.
+
+     Attributes
+     ----------
+     down_sampling_factor : int
+         Downsampling factor.
+     conv : torch.nn.Sequential
+         Convolutional path with 1x1 and 3x3 convolutions, outputting out_channels/2.
+     pool : torch.nn.Sequential
+         Max pooling path with 1x1 convolution, outputting out_channels/2.
+
+     Notes
+     -----
+     - The module splits the output channels evenly between convolutional and pooling
+       paths, concatenating them along the channel dimension.
+     - The convolutional path uses a stride equal to `down_sampling_factor`, while the
+       pooling path uses max pooling with the same factor.
+     """
+     def __init__(self, in_channels, out_channels, down_sampling_factor):
+         super().__init__()
+         self.down_sampling_factor = down_sampling_factor
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
+             nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2,
+                       kernel_size=3, stride=down_sampling_factor, padding=1)
+         )
+         self.pool = nn.Sequential(
+             nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
+             nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2,
+                       kernel_size=1, stride=1, padding=0)
+         )
+
+     def forward(self, batch):
+         """Downsamples input by combining convolutional and pooling paths.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input tensor, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Downsampled tensor, shape (batch_size, out_channels,
+             height/down_sampling_factor, width/down_sampling_factor).
+         """
+         return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
+ #------------------------------------------------------------------------------------------------
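
A shape-check sketch for DownSampling: half of the output channels come from the strided-convolution path and half from the max-pool path; the values and import path are illustrative assumptions.

    import torch
    from ldm.autoencoder import DownSampling  # assumed import path

    down = DownSampling(in_channels=64, out_channels=128, down_sampling_factor=2)
    x = torch.randn(2, 64, 32, 32)
    y = down(x)
    print(y.shape)  # torch.Size([2, 128, 16, 16]):
                    # 64 channels from the strided-conv path + 64 from the max-pool path
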
+ class Attention(nn.Module):
+     """Self-attention module for feature enhancement in AutoencoderLDM.
+
+     Applies multi-head self-attention to enhance features in the encoder and decoder,
+     used after downsampling (in DownBlock) and before upsampling (in UpBlock).
+
+     Parameters
+     ----------
+     num_channels : int
+         Number of input and output channels (embedding dimension for attention).
+     num_heads : int
+         Number of attention heads.
+     num_groups : int
+         Number of groups for group normalization.
+     dropout_rate : float
+         Dropout rate for attention outputs.
+
+     Attributes
+     ----------
+     group_norm : torch.nn.GroupNorm
+         Group normalization before attention.
+     attention : torch.nn.MultiheadAttention
+         Multi-head self-attention with `batch_first=True`.
+     dropout : torch.nn.Dropout
+         Dropout layer for regularization.
+
+     Notes
+     -----
+     - The input is reshaped to (batch_size, height * width, num_channels) for
+       attention processing, then restored to (batch_size, num_channels, height, width).
+     - Group normalization is applied before attention to stabilize training.
+     """
+     def __init__(self, num_channels, num_heads, num_groups, dropout_rate):
+         super().__init__()
+         self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
+         self.attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+     def forward(self, x):
+         """Applies self-attention to input features.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor, shape (batch_size, num_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor, same shape as input.
+         """
+         batch_size, channels, h, w = x.shape
+         x = x.reshape(batch_size, channels, h * w)
+         x = self.group_norm(x)
+         x = x.transpose(1, 2)
+         x, _ = self.attention(x, x, x)
+         x = self.dropout(x)
+         x = x.transpose(1, 2).reshape(batch_size, channels, h, w)
+         return x
+ #------------------------------------------------------------------------------------------------
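
A sketch of the Attention module above applied to a feature map: internally each spatial position becomes one token of a (height * width)-long sequence with num_channels-dimensional embeddings, and in AutoencoderLDM the result is combined with the input through a residual connection. Sizes and import path are assumptions.

    import torch
    from ldm.autoencoder import Attention  # assumed import path

    attn = Attention(num_channels=256, num_heads=4, num_groups=8, dropout_rate=0.1)
    x = torch.randn(2, 256, 16, 16)
    y = attn(x)
    print(y.shape)  # torch.Size([2, 256, 16, 16]): 16 * 16 = 256 tokens of width 256
                    # were attended over and reshaped back to the feature-map layout
    out = x + y     # residual use, as in AutoencoderLDM.encode / .decode
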
+ class UpBlock(nn.Module):
+     """Upsampling block for the decoder in AutoencoderLDM.
+
+     Applies upsampling followed by multiple convolutional layers with residual
+     connections to increase spatial dimensions in the decoder of the variational
+     autoencoder used in Latent Diffusion Models.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels for convolutional layers.
+     num_layers : int
+         Number of convolutional layer pairs (Conv3) per block.
+     up_sampling_factor : int
+         Factor by which to upsample spatial dimensions.
+     dropout_rate : float
+         Dropout rate for Conv3 layers.
+
+     Attributes
+     ----------
+     num_layers : int
+         Number of convolutional layer pairs.
+     up_sampling : UpSampling
+         Upsampling module to increase spatial dimensions.
+     conv1 : torch.nn.ModuleList
+         List of Conv3 layers for the first convolution in each pair.
+     conv2 : torch.nn.ModuleList
+         List of Conv3 layers for the second convolution in each pair.
+     resnet : torch.nn.ModuleList
+         List of 1x1 convolutional layers for residual connections.
+
+     Notes
+     -----
+     - Upsampling is applied first, followed by convolutional layer pairs with residual
+       connections using 1x1 convolutions.
+     - Each layer pair consists of two Conv3 modules.
+     """
+     def __init__(self, in_channels, out_channels, num_layers, up_sampling_factor, dropout_rate):
+         super().__init__()
+         self.num_layers = num_layers
+         effective_in_channels = in_channels
+
+         self.up_sampling = UpSampling(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             up_sampling_factor=up_sampling_factor
+         )
+
+         self.conv1 = nn.ModuleList([
+             Conv3(
+                 in_channels=effective_in_channels if i == 0 else out_channels,
+                 out_channels=out_channels,
+                 dropout_rate=dropout_rate
+             ) for i in range(self.num_layers)
+         ])
+         self.conv2 = nn.ModuleList([
+             Conv3(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 dropout_rate=dropout_rate
+             ) for _ in range(self.num_layers)
+         ])
+         self.resnet = nn.ModuleList([
+             nn.Conv2d(
+                 in_channels=effective_in_channels if i == 0 else out_channels,
+                 out_channels=out_channels,
+                 kernel_size=1
+             ) for i in range(self.num_layers)
+         ])
+
+     def forward(self, x):
+         """Processes input through upsampling and convolutional layers.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor, shape (batch_size, out_channels,
+             height * up_sampling_factor, width * up_sampling_factor).
+         """
+         x = self.up_sampling(x)
+         output = x
+         for i in range(self.num_layers):
+             resnet_input = output
+             output = self.conv1[i](output)
+             output = self.conv2[i](output)
+             output = output + self.resnet[i](resnet_input)
+         return output
+ #------------------------------------------------------------------------------------------------
+ class UpSampling(nn.Module):
+     """Upsampling module for increasing spatial dimensions in AutoencoderLDM’s decoder.
+
+     Combines transposed convolution and nearest-neighbor upsampling, concatenating
+     their outputs to preserve feature information during upsampling in UpBlock.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels (sum of conv and upsample paths).
+     up_sampling_factor : int
+         Factor by which to upsample spatial dimensions.
+
+     Attributes
+     ----------
+     up_sampling_factor : int
+         Upsampling factor.
+     conv : torch.nn.Sequential
+         Transposed convolutional path, outputting out_channels/2.
+     up_sample : torch.nn.Sequential
+         Nearest-neighbor upsampling path with 1x1 convolution, outputting
+         out_channels/2.
+
+     Notes
+     -----
+     - The module splits the output channels evenly between transposed convolution and
+       upsampling paths, concatenating them along the channel dimension.
+     - If the spatial dimensions of the two paths differ, the upsampling path is
+       interpolated to match the convolutional path’s size.
+     """
+     def __init__(self, in_channels, out_channels, up_sampling_factor):
+         super().__init__()
+         half_out_channels = out_channels // 2
+         self.up_sampling_factor = up_sampling_factor
+         self.conv = nn.Sequential(
+             nn.ConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=half_out_channels,
+                 kernel_size=3,
+                 stride=up_sampling_factor,
+                 padding=1,
+                 output_padding=up_sampling_factor - 1
+             ),
+             nn.Conv2d(
+                 in_channels=half_out_channels,
+                 out_channels=half_out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0
+             )
+         )
+         self.up_sample = nn.Sequential(
+             nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=half_out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0
+             )
+         )
+
+     def forward(self, batch):
+         """Upsamples input by combining transposed convolution and upsampling paths.
+
+         Parameters
+         ----------
+         batch : torch.Tensor
+             Input tensor, shape (batch_size, in_channels, height, width).
+
+         Returns
+         -------
+         torch.Tensor
+             Upsampled tensor, shape (batch_size, out_channels,
+             height * up_sampling_factor, width * up_sampling_factor).
+
+         Notes
+         -----
+         - Interpolation is applied if the spatial dimensions of the convolutional and
+           upsampling paths differ, using nearest-neighbor mode.
+         """
+         conv_output = self.conv(batch)
+         up_sample_output = self.up_sample(batch)
+         if conv_output.shape[2:] != up_sample_output.shape[2:]:
+             _, _, h, w = conv_output.shape
+             up_sample_output = torch.nn.functional.interpolate(
+                 up_sample_output,
+                 size=(h, w),
+                 mode='nearest'
+             )
+
+         return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
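
Finally, a sketch showing that UpSampling mirrors the encoder-side DownSampling: the transposed-convolution path and the nearest-neighbor path each contribute half of the output channels, and the spatial reduction is undone. Values and import path are illustrative assumptions.

    import torch
    from ldm.autoencoder import DownSampling, UpSampling  # assumed import path

    down = DownSampling(in_channels=64, out_channels=64, down_sampling_factor=2)
    up = UpSampling(in_channels=64, out_channels=64, up_sampling_factor=2)

    x = torch.randn(2, 64, 32, 32)
    h = down(x)
    y = up(h)
    print(h.shape)  # torch.Size([2, 64, 16, 16])
    print(y.shape)  # torch.Size([2, 64, 32, 32]): spatial size restored; half the
                    # channels come from ConvTranspose2d, half from Upsample + 1x1 conv
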