TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/autoencoder.py
ADDED
@@ -0,0 +1,855 @@
import torch
import torch.nn as nn
import torch.nn.functional as F




class AutoencoderLDM(nn.Module):
    """Variational autoencoder for latent space compression in Latent Diffusion Models.

    Encodes images into a latent space and decodes them back to the image space, used as
    the `compressor_model` in LDM’s `TrainLDM` and `SampleLDM`. Supports KL-divergence
    or vector quantization (VQ) regularization for the latent representation.

    Parameters
    ----------
    in_channels : int
        Number of input channels (e.g., 3 for RGB images).
    down_channels : list
        List of channel sizes for encoder downsampling blocks (e.g., [32, 64, 128, 256]).
    up_channels : list
        List of channel sizes for decoder upsampling blocks (e.g., [256, 128, 64, 16]).
    out_channels : int
        Number of output channels, typically equal to `in_channels`.
    dropout_rate : float
        Dropout rate for regularization in convolutional and attention layers.
    num_heads : int
        Number of attention heads in self-attention layers.
    num_groups : int
        Number of groups for group normalization in attention layers.
    num_layers_per_block : int
        Number of convolutional layers in each downsampling and upsampling block.
    total_down_sampling_factor : int
        Total downsampling factor across the encoder (e.g., 8 for 8x reduction).
    latent_channels : int
        Number of channels in the latent representation for diffusion models.
    num_embeddings : int
        Number of discrete embeddings in the VQ codebook (if `use_vq=True`).
    use_vq : bool, optional
        If True, uses vector quantization (VQ) regularization; otherwise, uses
        KL-divergence (default: False).
    beta : float, optional
        Weight for KL-divergence loss (if `use_vq=False`) (default: 1.0).

    Attributes
    ----------
    use_vq : bool
        Whether VQ regularization is used.
    beta : float
        Fixed weight for KL-divergence loss.
    current_beta : float
        Current weight for KL-divergence loss (modifiable during training).
    down_sampling_factor : int
        Downsampling factor per block, derived from `total_down_sampling_factor`.
    conv1 : torch.nn.Conv2d
        Initial convolutional layer for encoding.
    down_blocks : torch.nn.ModuleList
        List of DownBlock modules for encoder downsampling.
    attention1 : Attention
        Self-attention layer after encoder downsampling.
    vq_layer : VectorQuantizer or None
        Vector quantization layer (if `use_vq=True`).
    conv_mu : torch.nn.Conv2d or None
        Convolutional layer for mean of latent distribution (if `use_vq=False`).
    conv_logvar : torch.nn.Conv2d or None
        Convolutional layer for log-variance of latent distribution (if `use_vq=False`).
    quant_conv : torch.nn.Conv2d
        Convolutional layer to project latent representation to `latent_channels`.
    conv2 : torch.nn.Conv2d
        Initial convolutional layer for decoding.
    attention2 : Attention
        Self-attention layer after decoder’s initial convolution.
    up_blocks : torch.nn.ModuleList
        List of UpBlock modules for decoder upsampling.
    conv3 : Conv3
        Final convolutional layer for output reconstruction.

    Raises
    ------
    AssertionError
        If `in_channels` does not equal `out_channels`.

    Notes
    -----
    - The encoder downsamples images using `DownBlock` modules, followed by self-attention
      and latent projection (VQ or KL-based).
    - The decoder upsamples the latent representation using `UpBlock` modules, with
      self-attention and final convolution.
    - The `down_sampling_factor` is computed as `total_down_sampling_factor` raised to
      the power of `1 / (len(down_channels) - 1)`, applied per downsampling block.
    - The latent representation has `latent_channels` channels, suitable for LDM’s
      diffusion process.
    """
    def __init__(
            self,
            in_channels,  # number of channels of the original image, e.g., 3 for RGB.
            down_channels,  # a list of channels used in the encoder, e.g., [32, 64, 128, 256].
            up_channels,  # a list of channels used in the decoder, e.g., [256, 128, 64, 16].
            out_channels,  # usually the same as in_channels; used to reconstruct the image.
            dropout_rate,  # dropout rate, prevents overfitting.
            num_heads,  # number of attention heads in self-attention layers.
            num_groups,  # number of groups in group normalization, used in self-attention.
            num_layers_per_block,  # number of convolutional layers within each down/up block.
            total_down_sampling_factor,  # total down-sampling factor, used to derive the integer per-block factor that down/up samples the input batch of images.
            latent_channels,  # final z channels for the diffusion model.
            num_embeddings,  # number of discrete embeddings in the codebook, in case of using VectorQuantizer.
            use_vq=False,  # flag to toggle between vq regularization and kl regularization; if false, uses kl.
            beta=1.0  # weight for KL loss.
    ):
        super().__init__()
        assert in_channels == out_channels, "Input and output channels must match for auto-encoding"
        self.use_vq = use_vq
        self.beta = beta
        self.current_beta = beta
        num_down_blocks = len(down_channels) - 1
        self.down_sampling_factor = int(total_down_sampling_factor ** (1 / num_down_blocks))

        # Encoder
        self.conv1 = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, padding=1)
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=down_channels[i],
                out_channels=down_channels[i + 1],
                num_layers=num_layers_per_block,
                down_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(num_down_blocks)
        ])
        self.attention1 = Attention(down_channels[-1], num_heads, num_groups, dropout_rate)

        # Latent projection
        if use_vq:
            self.vq_layer = VectorQuantizer(num_embeddings, down_channels[-1])
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)
        else:
            self.conv_mu = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.conv_logvar = nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1)
            self.quant_conv = nn.Conv2d(down_channels[-1], latent_channels, kernel_size=1)

        # Decoder
        self.conv2 = nn.Conv2d(latent_channels, up_channels[0], kernel_size=3, padding=1)
        self.attention2 = Attention(up_channels[0], num_heads, num_groups, dropout_rate)
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=up_channels[i],
                out_channels=up_channels[i + 1],
                num_layers=num_layers_per_block,
                up_sampling_factor=self.down_sampling_factor,
                dropout_rate=dropout_rate
            ) for i in range(len(up_channels) - 1)
        ])
        self.conv3 = Conv3(up_channels[-1], out_channels, dropout_rate)

    def reparameterize(self, mu, logvar):
        """Applies reparameterization trick for variational autoencoding.

        Samples from a Gaussian distribution using the mean and log-variance to enable
        differentiable training.

        Parameters
        ----------
        mu : torch.Tensor
            Mean of the latent distribution, shape (batch_size, channels, height, width).
        logvar : torch.Tensor
            Log-variance of the latent distribution, same shape as `mu`.

        Returns
        -------
        torch.Tensor
            Sampled latent representation, same shape as `mu`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Encodes images into a latent representation.

        Processes input images through the encoder, applying convolutions, downsampling,
        self-attention, and latent projection (VQ or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        tuple
            A tuple containing:
            - z: Latent representation, shape (batch_size, latent_channels,
              height/down_sampling_factor, width/down_sampling_factor).
            - reg_loss: Regularization loss (VQ loss if `use_vq=True`, KL-divergence
              loss if `use_vq=False`).

        Notes
        -----
        - The VQ loss is computed by `VectorQuantizer` if `use_vq=True`.
        - The KL-divergence loss is normalized by batch size and latent size, weighted
          by `current_beta`.
        """
        x = self.conv1(x)
        for block in self.down_blocks:
            x = block(x)
        res_x = x
        x = self.attention1(x)
        x = x + res_x
        if self.use_vq:
            z, vq_loss = self.vq_layer(x)
            z = self.quant_conv(z)
            return z, vq_loss
        else:
            mu = self.conv_mu(x)
            logvar = self.conv_logvar(x)
            z = self.reparameterize(mu, logvar)
            z = self.quant_conv(z)
            kl_unnormalized = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            batch_size = x.size(0)
            latent_size = torch.prod(torch.tensor(mu.shape[1:])).item()
            kl_loss = kl_unnormalized / (batch_size * latent_size) * self.current_beta
            return z, kl_loss

    def decode(self, z):
        """Decodes latent representations back to images.

        Processes latent representations through the decoder, applying convolutions,
        self-attention, upsampling, and final reconstruction.

        Parameters
        ----------
        z : torch.Tensor
            Latent representation, shape (batch_size, latent_channels,
            height/down_sampling_factor, width/down_sampling_factor).

        Returns
        -------
        torch.Tensor
            Reconstructed images, shape (batch_size, out_channels, height, width).
        """
        x = self.conv2(z)
        res_x = x
        x = self.attention2(x)
        x = x + res_x
        for block in self.up_blocks:
            x = block(x)
        x = self.conv3(x)
        return x

    def forward(self, x):
        """Encodes images to latent space and decodes them, computing reconstruction and regularization losses.

        Performs a full autoencoding pass, encoding images to the latent space, decoding
        them back, and calculating MSE reconstruction loss and regularization loss (VQ
        or KL-based).

        Parameters
        ----------
        x : torch.Tensor
            Input images, shape (batch_size, in_channels, height, width).

        Returns
        -------
        tuple
            A tuple containing:
            - x_hat: Reconstructed images, shape (batch_size, out_channels, height,
              width).
            - total_loss: Sum of reconstruction (MSE) and regularization losses.
            - reg_loss: Regularization loss (VQ or KL-divergence).
            - z: Latent representation, shape (batch_size, latent_channels,
              height/down_sampling_factor, width/down_sampling_factor).

        Notes
        -----
        - The reconstruction loss is computed as the mean squared error between `x_hat`
          and `x`.
        - The regularization loss depends on `use_vq` (VQ loss or KL-divergence).
        """
        z, reg_loss = self.encode(x)
        x_hat = self.decode(z)
        recon_loss = F.mse_loss(x_hat, x)
        total_loss = recon_loss + reg_loss
        return x_hat, total_loss, reg_loss, z  # return z for DM
#------------------------------------------------------------------------------------------------
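# Editor's illustrative sketch (not part of the released file): a minimal example of how
# AutoencoderLDM might be instantiated and run on a dummy batch. Every hyperparameter value
# below is an assumption chosen for the demo, not a package default.
def _demo_autoencoder_ldm():
    """Sketch: one KL-regularized forward pass through AutoencoderLDM on random data."""
    model = AutoencoderLDM(
        in_channels=3,
        down_channels=[32, 64, 128],   # two DownBlocks: 32->64 and 64->128
        up_channels=[128, 64, 32],     # two UpBlocks: 128->64 and 64->32
        out_channels=3,
        dropout_rate=0.1,
        num_heads=4,
        num_groups=8,
        num_layers_per_block=2,
        total_down_sampling_factor=4,  # per-block factor = int(4 ** (1 / 2)) = 2
        latent_channels=4,
        num_embeddings=512,            # only used when use_vq=True
        use_vq=False,                  # KL regularization, weighted by current_beta
    )
    x = torch.randn(2, 3, 64, 64)
    x_hat, total_loss, reg_loss, z = model(x)
    # x_hat: (2, 3, 64, 64); z: (2, 4, 16, 16) given the 4x total downsampling chosen above
    return x_hat, total_loss, reg_loss, z
#------------------------------------------------------------------------------------------------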
class VectorQuantizer(nn.Module):
    """Vector quantization layer for discretizing latent representations.

    Quantizes input latent vectors to the nearest embedding in a learned codebook,
    used in `AutoencoderLDM` when `use_vq=True` to enable discrete latent spaces for
    Latent Diffusion Models. Computes commitment and codebook losses to train the
    codebook embeddings.

    Parameters
    ----------
    num_embeddings : int
        Number of discrete embeddings in the codebook.
    embedding_dim : int
        Dimensionality of each embedding vector (matches input channel dimension).
    commitment_cost : float, optional
        Weight for the commitment loss, encouraging inputs to be close to quantized
        values (default: 0.25).

    Attributes
    ----------
    embedding_dim : int
        Dimensionality of embedding vectors.
    num_embeddings : int
        Number of embeddings in the codebook.
    commitment_cost : float
        Weight for commitment loss.
    embedding : torch.nn.Embedding
        Embedding layer containing the codebook, shape (num_embeddings,
        embedding_dim).

    Notes
    -----
    - The codebook embeddings are initialized uniformly in the range
      [-1/num_embeddings, 1/num_embeddings].
    - The forward pass flattens input latents, computes Euclidean distances to
      codebook embeddings, and selects the nearest embedding for quantization.
    - The commitment loss encourages input latents to be close to their quantized
      versions, while the codebook loss updates embeddings to match inputs.
    - A straight-through estimator is used to pass gradients from the quantized output
      to the input.
    """
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        # dimensionality of each embedding vector
        self.embedding_dim = embedding_dim
        # number of discrete embeddings in the codebook
        self.num_embeddings = num_embeddings
        # commitment cost for the loss term that encourages z to be close to its quantized values
        self.commitment_cost = commitment_cost
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # initialize embedding weights uniformly
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, z):
        """Quantizes latent representations to the nearest codebook embedding.

        Computes the closest embedding for each input vector, applies quantization,
        and calculates commitment and codebook losses for training.

        Parameters
        ----------
        z : torch.Tensor
            Input latent representation, shape (batch_size, embedding_dim, height,
            width).

        Returns
        -------
        tuple
            A tuple containing:
            - quantized: Quantized latent representation, same shape as `z`.
            - vq_loss: Sum of commitment and codebook losses.

        Raises
        ------
        AssertionError
            If the channel dimension of `z` does not match `embedding_dim`.

        Notes
        -----
        - The input is flattened to (batch_size * height * width, embedding_dim) for
          distance computation.
        - Euclidean distances are computed efficiently using vectorized operations.
        - The commitment loss is scaled by `commitment_cost`, and the total VQ loss
          combines commitment and codebook losses.
        """
        z = z.contiguous()  # ensure the tensor is contiguous in memory
        # flatten z to (batch_size * height * width, embedding_dim) for distance computation
        assert z.size(1) == self.embedding_dim, f"Expected channel dim {self.embedding_dim}, got {z.size(1)}"
        z_flattened = z.reshape(-1, self.embedding_dim)
        # compute squared euclidean distances between z_flattened and all embeddings
        distances = (torch.sum(z_flattened ** 2, dim=1, keepdim=True)
                     + torch.sum(self.embedding.weight ** 2, dim=1)
                     - 2 * torch.matmul(z_flattened, self.embedding.weight.t()))
        # find the index of the closest embedding for each z_flattened vector
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # convert indices to one-hot encodings
        encodings = F.one_hot(encoding_indices, self.num_embeddings).float().squeeze(1)
        # map one-hot encodings to quantized values using the embedding weights
        quantized = torch.matmul(encodings, self.embedding.weight).view_as(z)
        # commitment loss to encourage z to be close to its quantized version
        commitment_loss = self.commitment_cost * torch.mean((z.detach() - quantized) ** 2)
        # codebook loss to encourage embeddings to move closer to z
        codebook_loss = torch.mean((z - quantized.detach()) ** 2)
        # straight-through estimator: copy gradients from quantized to z
        quantized = z + (quantized - z).detach()
        # return the quantized tensor and the combined vq loss
        return quantized, commitment_loss + codebook_loss
#------------------------------------------------------------------------------------------------
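# Editor's illustrative sketch (not part of the released file): quantize a small latent with
# VectorQuantizer and confirm that the straight-through estimator lets gradients reach the
# continuous input. The tensor sizes are arbitrary demo values.
def _demo_vector_quantizer():
    vq = VectorQuantizer(num_embeddings=16, embedding_dim=8)
    z = torch.randn(1, 8, 4, 4, requires_grad=True)
    quantized, vq_loss = vq(z)      # quantized has the same shape as z; vq_loss is a scalar
    quantized.sum().backward()      # gradient flows back to z through the straight-through path
    assert quantized.shape == z.shape and z.grad is not None
    return quantized, vq_loss
#------------------------------------------------------------------------------------------------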
class DownBlock(nn.Module):
    """Downsampling block for the encoder in AutoencoderLDM.

    Applies multiple convolutional layers with residual connections followed by
    downsampling to reduce spatial dimensions in the encoder of the variational
    autoencoder used in Latent Diffusion Models.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels for convolutional layers.
    num_layers : int
        Number of convolutional layer pairs (Conv3) per block.
    down_sampling_factor : int
        Factor by which to downsample spatial dimensions.
    dropout_rate : float
        Dropout rate for Conv3 layers.

    Attributes
    ----------
    num_layers : int
        Number of convolutional layer pairs.
    conv1 : torch.nn.ModuleList
        List of Conv3 layers for the first convolution in each pair.
    conv2 : torch.nn.ModuleList
        List of Conv3 layers for the second convolution in each pair.
    down_sampling : DownSampling
        Downsampling module to reduce spatial dimensions.
    resnet : torch.nn.ModuleList
        List of 1x1 convolutional layers for residual connections.

    Notes
    -----
    - Each layer pair consists of two Conv3 modules with a residual connection using a
      1x1 convolution to match dimensions.
    - The downsampling is applied after all convolutional layers, reducing spatial
      dimensions by `down_sampling_factor`.
    """
    def __init__(self, in_channels, out_channels, num_layers, down_sampling_factor, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])

        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor
        )
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x):
        """Processes input through convolutional layers and downsampling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor, shape (batch_size, out_channels,
            height/down_sampling_factor, width/down_sampling_factor).
        """
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
        output = self.down_sampling(output)
        return output
# ------------------------------------------------------------------------------------------------
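# Editor's illustrative sketch (not part of the released file): shape trace through a single
# DownBlock with assumed demo sizes.
def _demo_down_block():
    block = DownBlock(in_channels=32, out_channels=64, num_layers=2,
                      down_sampling_factor=2, dropout_rate=0.0)
    x = torch.randn(2, 32, 32, 32)
    y = block(x)
    # the residual Conv3 pairs keep 32x32; DownSampling then halves it to 16x16 with 64 channels
    assert y.shape == (2, 64, 16, 16)
    return y
#------------------------------------------------------------------------------------------------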
class Conv3(nn.Module):
    """Convolutional layer with group normalization, SiLU activation, and dropout.

    Used in DownBlock and UpBlock of AutoencoderLDM for feature extraction and
    transformation in the encoder and decoder.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    dropout_rate : float
        Dropout rate for regularization.

    Attributes
    ----------
    group_norm : torch.nn.GroupNorm
        Group normalization with 8 groups.
    activation : torch.nn.SiLU
        SiLU (Swish) activation function.
    conv : torch.nn.Conv2d
        3x3 convolutional layer with padding to maintain spatial dimensions.
    dropout : torch.nn.Dropout
        Dropout layer for regularization.

    Notes
    -----
    - The layer applies group normalization, SiLU activation, dropout, and a 3x3
      convolution in sequence.
    - Spatial dimensions are preserved due to padding=1 in the convolution.
    """
    def __init__(self, in_channels, out_channels, dropout_rate):
        super().__init__()
        self.group_norm = nn.GroupNorm(num_groups=8, num_channels=in_channels)
        self.activation = nn.SiLU()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Processes input through group normalization, activation, dropout, and convolution.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor, shape (batch_size, out_channels, height, width).
        """
        x = self.group_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.conv(x)
        return x
#------------------------------------------------------------------------------------------------
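# Editor's illustrative sketch (not part of the released file): Conv3 preserves spatial size;
# note that in_channels must be divisible by the 8 GroupNorm groups hard-coded above.
def _demo_conv3():
    layer = Conv3(in_channels=32, out_channels=64, dropout_rate=0.0)
    x = torch.randn(2, 32, 16, 16)
    y = layer(x)
    assert y.shape == (2, 64, 16, 16)
    return y
#------------------------------------------------------------------------------------------------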
class DownSampling(nn.Module):
    """Downsampling module for reducing spatial dimensions in AutoencoderLDM’s encoder.

    Combines convolutional downsampling and max pooling, concatenating their outputs
    to preserve feature information during downsampling in DownBlock.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels (sum of conv and pool paths).
    down_sampling_factor : int
        Factor by which to downsample spatial dimensions.

    Attributes
    ----------
    down_sampling_factor : int
        Downsampling factor.
    conv : torch.nn.Sequential
        Convolutional path with 1x1 and 3x3 convolutions, outputting out_channels/2.
    pool : torch.nn.Sequential
        Max pooling path with 1x1 convolution, outputting out_channels/2.

    Notes
    -----
    - The module splits the output channels evenly between convolutional and pooling
      paths, concatenating them along the channel dimension.
    - The convolutional path uses a stride equal to `down_sampling_factor`, while the
      pooling path uses max pooling with the same factor.
    """
    def __init__(self, in_channels, out_channels, down_sampling_factor):
        super().__init__()
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        )
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2,
                      kernel_size=1, stride=1, padding=0)
        )

    def forward(self, batch):
        """Downsamples input by combining convolutional and pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Downsampled tensor, shape (batch_size, out_channels,
            height/down_sampling_factor, width/down_sampling_factor).
        """
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
#------------------------------------------------------------------------------------------------
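# Editor's illustrative sketch (not part of the released file): the strided-conv and max-pool
# paths of DownSampling each produce out_channels // 2 features and are concatenated, so
# out_channels should be even. Sizes below are demo assumptions.
def _demo_down_sampling():
    ds = DownSampling(in_channels=64, out_channels=64, down_sampling_factor=2)
    x = torch.randn(1, 64, 16, 16)
    y = ds(x)
    assert y.shape == (1, 64, 8, 8)
    return y
#------------------------------------------------------------------------------------------------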
class Attention(nn.Module):
    """Self-attention module for feature enhancement in AutoencoderLDM.

    Applies multi-head self-attention to enhance features in the encoder and decoder,
    used after downsampling (in DownBlock) and before upsampling (in UpBlock).

    Parameters
    ----------
    num_channels : int
        Number of input and output channels (embedding dimension for attention).
    num_heads : int
        Number of attention heads.
    num_groups : int
        Number of groups for group normalization.
    dropout_rate : float
        Dropout rate for attention outputs.

    Attributes
    ----------
    group_norm : torch.nn.GroupNorm
        Group normalization before attention.
    attention : torch.nn.MultiheadAttention
        Multi-head self-attention with `batch_first=True`.
    dropout : torch.nn.Dropout
        Dropout layer for regularization.

    Notes
    -----
    - The input is reshaped to (batch_size, height * width, num_channels) for
      attention processing, then restored to (batch_size, num_channels, height, width).
    - Group normalization is applied before attention to stabilize training.
    """
    def __init__(self, num_channels, num_heads, num_groups, dropout_rate):
        super().__init__()
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
        self.attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        """Applies self-attention to input features.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, num_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input.
        """
        batch_size, channels, h, w = x.shape
        x = x.reshape(batch_size, channels, h * w)
        x = self.group_norm(x)
        x = x.transpose(1, 2)
        x, _ = self.attention(x, x, x)
        x = self.dropout(x)
        x = x.transpose(1, 2).reshape(batch_size, channels, h, w)
        return x
#------------------------------------------------------------------------------------------------
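# Editor's illustrative sketch (not part of the released file): Attention treats each spatial
# position as a token and returns a tensor with the input's shape. num_channels must be
# divisible by both num_heads and num_groups (demo values below).
def _demo_attention():
    attn = Attention(num_channels=64, num_heads=4, num_groups=8, dropout_rate=0.0)
    x = torch.randn(2, 64, 8, 8)
    y = attn(x)
    # internally: (2, 64, 8, 8) -> (2, 64, 64) -> GroupNorm -> (2, 64 tokens, 64 dims) -> back
    assert y.shape == x.shape
    return y
#------------------------------------------------------------------------------------------------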
class UpBlock(nn.Module):
    """Upsampling block for the decoder in AutoencoderLDM.

    Applies upsampling followed by multiple convolutional layers with residual
    connections to increase spatial dimensions in the decoder of the variational
    autoencoder used in Latent Diffusion Models.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels for convolutional layers.
    num_layers : int
        Number of convolutional layer pairs (Conv3) per block.
    up_sampling_factor : int
        Factor by which to upsample spatial dimensions.
    dropout_rate : float
        Dropout rate for Conv3 layers.

    Attributes
    ----------
    num_layers : int
        Number of convolutional layer pairs.
    up_sampling : UpSampling
        Upsampling module to increase spatial dimensions.
    conv1 : torch.nn.ModuleList
        List of Conv3 layers for the first convolution in each pair.
    conv2 : torch.nn.ModuleList
        List of Conv3 layers for the second convolution in each pair.
    resnet : torch.nn.ModuleList
        List of 1x1 convolutional layers for residual connections.

    Notes
    -----
    - Upsampling is applied first, followed by convolutional layer pairs with residual
      connections using 1x1 convolutions.
    - Each layer pair consists of two Conv3 modules.
    """
    def __init__(self, in_channels, out_channels, num_layers, up_sampling_factor, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        effective_in_channels = in_channels

        self.up_sampling = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor
        )

        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(self.num_layers)
        ])

    def forward(self, x):
        """Processes input through upsampling and convolutional layers.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor, shape (batch_size, out_channels,
            height * up_sampling_factor, width * up_sampling_factor).
        """
        x = self.up_sampling(x)
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)
        return output
#------------------------------------------------------------------------------------------------
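# Editor's illustrative sketch (not part of the released file): an UpBlock upsamples first and
# then applies its residual Conv3 pairs, so spatial size doubles while channels change from
# in_channels to out_channels (demo sizes below).
def _demo_up_block():
    block = UpBlock(in_channels=128, out_channels=64, num_layers=2,
                    up_sampling_factor=2, dropout_rate=0.0)
    x = torch.randn(1, 128, 8, 8)
    y = block(x)
    assert y.shape == (1, 64, 16, 16)
    return y
#------------------------------------------------------------------------------------------------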
class UpSampling(nn.Module):
    """Upsampling module for increasing spatial dimensions in AutoencoderLDM’s decoder.

    Combines transposed convolution and nearest-neighbor upsampling, concatenating
    their outputs to preserve feature information during upsampling in UpBlock.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels (sum of conv and upsample paths).
    up_sampling_factor : int
        Factor by which to upsample spatial dimensions.

    Attributes
    ----------
    up_sampling_factor : int
        Upsampling factor.
    conv : torch.nn.Sequential
        Transposed convolutional path, outputting out_channels/2.
    up_sample : torch.nn.Sequential
        Nearest-neighbor upsampling path with 1x1 convolution, outputting
        out_channels/2.

    Notes
    -----
    - The module splits the output channels evenly between transposed convolution and
      upsampling paths, concatenating them along the channel dimension.
    - If the spatial dimensions of the two paths differ, the upsampling path is
      interpolated to match the convolutional path’s size.
    """
    def __init__(self, in_channels, out_channels, up_sampling_factor):
        super().__init__()
        half_out_channels = out_channels // 2
        self.up_sampling_factor = up_sampling_factor
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=half_out_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=half_out_channels,
                out_channels=half_out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )
        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=half_out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )

    def forward(self, batch):
        """Upsamples input by combining transposed convolution and upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels,
            height * up_sampling_factor, width * up_sampling_factor).

        Notes
        -----
        - Interpolation is applied if the spatial dimensions of the convolutional and
          upsampling paths differ, using nearest-neighbor mode.
        """
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )

        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
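#------------------------------------------------------------------------------------------------
# Editor's illustrative smoke test (not part of the released file): an encode/decode round trip
# with the VQ-regularized variant, using the same hypothetical hyperparameters as the sketches
# above. It only checks that the reconstruction recovers the input resolution.
if __name__ == "__main__":
    autoencoder = AutoencoderLDM(
        in_channels=3, down_channels=[32, 64, 128], up_channels=[128, 64, 32],
        out_channels=3, dropout_rate=0.1, num_heads=4, num_groups=8,
        num_layers_per_block=2, total_down_sampling_factor=4,
        latent_channels=4, num_embeddings=512, use_vq=True,
    )
    images = torch.randn(2, 3, 64, 64)
    latents, reg = autoencoder.encode(images)   # latents: (2, 4, 16, 16); reg: VQ loss
    recon = autoencoder.decode(latents)         # recon: (2, 3, 64, 64)
    print("latents:", tuple(latents.shape), "recon:", tuple(recon.shape), "vq loss:", float(reg))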