TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/noise_predictor.py
ADDED
|
@@ -0,0 +1,1074 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class NoisePredictor(nn.Module):
    """U-Net-style noise predictor operating on latent tensors.

    The network is a stem convolution, a stack of `DownBlock` modules, a stack of
    `MiddleBlock` modules, a stack of `UpBlock` modules fed with skip connections
    from the down path, and a final normalisation/dropout/convolution head that
    maps back to `in_channels`.

    Parameters
    ----------
    in_channels : int
        Channel count of the input tensor; also the channel count of the output.
    down_channels : list of int
        Channel schedule of the down path. Block ``i`` maps ``down_channels[i]``
        to ``down_channels[i + 1]``, so ``len(down_channels) - 1`` blocks are built.
    mid_channels : list of int
        Channel schedule of the middle path (same pairing rule as `down_channels`).
    up_channels : list of int
        Channel schedule of the up path (same pairing rule as `down_channels`).
    down_sampling : list of bool
        ``down_sampling[i]`` enables spatial downsampling in down block ``i``.
    time_embed_dim : int
        Width of the time embedding.
    y_embed_dim : int
        Width of the conditioning embedding handed to every block.
    num_down_blocks : int
        `num_layers` passed to each `DownBlock`.
    num_mid_blocks : int
        `num_layers` passed to each `MiddleBlock`.
    num_up_blocks : int
        `num_layers` passed to each `UpBlock`.
    dropout_rate : float, optional
        Dropout probability forwarded to every block and used in the output head
        (default: 0.1).
    down_sampling_factor : int, optional
        Passed to down blocks as `down_sampling_factor` and to up blocks as
        `up_sampling_factor` (default: 2).
    where_y : bool, optional
        When False and `y` is supplied to `forward`, `y` is concatenated to `x`
        along dim 1 before the stem convolution. `y` is handed to every block in
        either case (default: True).
    y_to_all : bool, optional
        Forwarded unchanged to every block (default: False).

    Attributes
    ----------
    up_sampling : list of bool
        `down_sampling` in reverse order; ``up_sampling[i]`` configures up block ``i``.
    conv1 : torch.nn.Conv2d
        3x3 stem convolution, ``in_channels -> down_channels[0]``.
    time_projection : torch.nn.Sequential
        Linear -> SiLU -> Linear, all of width `time_embed_dim`.
    down_blocks, mid_blocks, up_blocks : torch.nn.ModuleList
        The three block stacks described above.
    conv2 : torch.nn.Sequential
        GroupNorm(8 groups) -> Dropout -> 3x3 convolution back to `in_channels`.

    Notes
    -----
    - The constructor does not call `initialize_weights`; call it explicitly if
      the Kaiming initialisation is wanted.
    - Up block ``i`` is declared with ``skip_channels = reversed(down_channels)[i]``.
    """

    def __init__(
            self,
            in_channels,
            down_channels,
            mid_channels,
            up_channels,
            down_sampling,
            time_embed_dim,
            y_embed_dim,  # output embedding dimension in text conditional net
            num_down_blocks,
            num_mid_blocks,
            num_up_blocks,
            dropout_rate=0.1,
            down_sampling_factor=2,
            where_y=True,
            y_to_all=False
    ):
        super().__init__()

        # Plain configuration attributes.
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.up_sampling = [flag for flag in reversed(self.down_sampling)]

        # Sub-modules are created in a fixed order (stem, time projection, down,
        # middle, up, head) so parameter registration order stays stable.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=down_channels[0],
            kernel_size=3,
            padding=1
        )

        embed_width = time_embed_dim
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=embed_width, out_features=embed_width),
            nn.SiLU(),
            nn.Linear(in_features=embed_width, out_features=embed_width)
        )

        # Keyword arguments shared by all three block types.
        shared = dict(
            time_embed_dim=time_embed_dim,
            y_embed_dim=y_embed_dim,
            dropout_rate=dropout_rate,
            y_to_all=y_to_all
        )

        encoder_stages = []
        for idx in range(len(down_channels) - 1):
            encoder_stages.append(
                DownBlock(
                    in_channels=down_channels[idx],
                    out_channels=down_channels[idx + 1],
                    num_layers=num_down_blocks,
                    down_sampling_factor=down_sampling_factor,
                    down_sample=down_sampling[idx],
                    **shared
                )
            )
        self.down_blocks = nn.ModuleList(encoder_stages)

        bottleneck_stages = []
        for idx in range(len(mid_channels) - 1):
            bottleneck_stages.append(
                MiddleBlock(
                    in_channels=mid_channels[idx],
                    out_channels=mid_channels[idx + 1],
                    num_layers=num_mid_blocks,
                    **shared
                )
            )
        self.mid_blocks = nn.ModuleList(bottleneck_stages)

        # Skip widths are declared from the down schedule read back-to-front.
        skip_channels = list(reversed(down_channels))
        decoder_stages = []
        for idx in range(len(up_channels) - 1):
            decoder_stages.append(
                UpBlock(
                    in_channels=up_channels[idx],
                    out_channels=up_channels[idx + 1],
                    skip_channels=skip_channels[idx],
                    num_layers=num_up_blocks,
                    up_sampling_factor=down_sampling_factor,
                    up_sampling=self.up_sampling[idx],
                    **shared
                )
            )
        self.up_blocks = nn.ModuleList(decoder_stages)

        head_width = up_channels[-1]
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=head_width),
            nn.Dropout(p=dropout_rate),
            nn.Conv2d(in_channels=head_width, out_channels=in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self):
        """Re-initialise convolution and linear layers in place.

        Every `Conv2d`, `Linear` and `ConvTranspose2d` found in `self.modules()`
        gets Kaiming-normal weights (``a=0.2``, ``nonlinearity='leaky_relu'``) and,
        when it has a bias, a zero bias. Other module types are left untouched.
        """
        kaiming_targets = (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)
        for layer in self.modules():
            if not isinstance(layer, kaiming_targets):
                continue
            nn.init.kaiming_normal_(layer.weight, a=0.2, nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x, t, y=None):
        """Run the full network on a latent batch.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor fed to the 2-D stem convolution.
        t : torch.Tensor
            Time steps, passed to `GetEmbeddedTime` as `time_steps`.
        y : torch.Tensor, optional
            Conditioning tensor handed to every block (default: None).
            NOTE(review): when `where_y` is False it is also concatenated to `x`
            on dim 1, which requires it to be shape-compatible with `x` there.

        Returns
        -------
        torch.Tensor
            Output of the final head, with `in_channels` channels.
        """
        if y is not None and not self.where_y:
            x = torch.cat(tensors=[x, y], dim=1)

        features = self.conv1(x)
        raw_time = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(raw_time)

        # Each down block's *input* is stored; the up path consumes them LIFO.
        encoder_inputs = []
        for down in self.down_blocks:
            encoder_inputs.append(features)
            features = down(x=features, embed_time=time_embed, y=y)

        for mid in self.mid_blocks:
            features = mid(x=features, embed_time=time_embed, y=y)

        for up in self.up_blocks:
            skip = encoder_inputs.pop()
            features = up(x=features, skip_connection=skip, embed_time=time_embed, y=y)

        return self.conv2(features)
|
|
237
|
+
#-----------------------------------------------------------------------------
|
|
238
|
+
class DownBlock(nn.Module):
    """Encoder stage of NoisePredictor: residual conv pairs, attention, downsampling.

    Each of the `num_layers` layers runs Conv3 -> add projected time embedding ->
    Conv3, adds a 1x1-convolved shortcut of the layer input, and (on the first
    layer, or on every layer when `y_to_all` is True) adds the output of an
    `Attention` module. The result is finally passed through `down_sampling`.

    Parameters
    ----------
    in_channels : int
        Channel count of the block input.
    out_channels : int
        Channel count produced by every layer of the block.
    time_embed_dim : int
        Width of the time embedding given to each `TimeEmbedding`.
    y_embed_dim : int
        Width of the conditioning embedding given to each `Attention`.
    num_layers : int
        Number of residual conv pairs.
    down_sampling_factor : int
        Factor handed to `DownSampling`.
    down_sample : bool
        If True a `DownSampling` module is used; otherwise `nn.Identity`.
    dropout_rate : float
        Dropout probability for the Conv3 and Attention modules.
    y_to_all : bool
        If True attention runs after every layer; if False only after layer 0.

    Attributes
    ----------
    num_layers : int
        Number of residual conv pairs.
    y_to_all : bool
        Attention scope flag.
    conv1, conv2 : torch.nn.ModuleList
        First and second Conv3 of each pair.
    time_embedding : torch.nn.ModuleList
        One `TimeEmbedding` per layer, projecting to `out_channels`.
    attention : torch.nn.ModuleList
        One `Attention` per layer (4 heads, 8 norm groups).
    down_sampling : DownSampling or torch.nn.Identity
        Final spatial reduction, or a no-op.
    resnet : torch.nn.ModuleList
        1x1 convolutions used as residual shortcuts.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers,
                 down_sampling_factor, down_sample, dropout_rate, y_to_all):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        # Only the first layer sees `in_channels`; later layers are out -> out.
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        # 1x1 shortcuts so the residual matches `out_channels`.
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x, embed_time, y=None):
        """Apply the residual layers, attention, and the final downsampling.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map with `in_channels` channels.
        embed_time : torch.Tensor
            Time embedding; each layer projects it and broadcasts it over the two
            spatial dimensions.
        y : torch.Tensor, optional
            Conditioning tensor passed to the attention modules. When None the
            attention modules are called with the feature map only
            (default: None, matching `MiddleBlock.forward` and `UpBlock.forward`).

        Returns
        -------
        torch.Tensor
            Output of `down_sampling` applied to the last layer's result.
        """
        output = x
        for i in range(self.num_layers):
            shortcut = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](shortcut)
            # Attention always runs on layer 0, and on every layer with y_to_all.
            if self.y_to_all or i == 0:
                if y is None:
                    out_attn = self.attention[i](output)
                else:
                    out_attn = self.attention[i](output, y)
                output = output + out_attn

        return self.down_sampling(output)
|
|
383
|
+
#------------------------------------------------------------------------------
|
|
384
|
+
class MiddleBlock(nn.Module):
    """Bottleneck stage of NoisePredictor: residual conv pairs interleaved with attention.

    The block owns ``num_layers + 1`` residual units (Conv3 -> add projected time
    embedding -> Conv3 -> add 1x1 shortcut). Unit 0 runs first; then, for each
    ``i`` in ``range(num_layers)``, an optional attention step is followed by
    unit ``i + 1``. No spatial resampling module is used.

    Parameters
    ----------
    in_channels : int
        Channel count of the block input.
    out_channels : int
        Channel count produced by every residual unit.
    time_embed_dim : int
        Width of the time embedding given to each `TimeEmbedding`.
    y_embed_dim : int
        Width of the conditioning embedding given to each `Attention`.
    num_layers : int
        Number of attention/unit iterations after the initial unit.
    dropout_rate : float
        Dropout probability for the Conv3 and Attention modules.
    y_to_all : bool, optional
        If True attention runs in every iteration; if False only in iteration 0
        (default: False).

    Attributes
    ----------
    num_layers : int
        Number of loop iterations in `forward`.
    y_to_all : bool
        Attention scope flag.
    conv1, conv2 : torch.nn.ModuleList
        First and second Conv3 of each residual unit (``num_layers + 1`` entries).
    time_embedding : torch.nn.ModuleList
        One `TimeEmbedding` per residual unit.
    attention : torch.nn.ModuleList
        ``num_layers + 1`` `Attention` modules; `forward` only indexes the first
        ``num_layers`` of them.
    resnet : torch.nn.ModuleList
        1x1 convolutions used as residual shortcuts.
    """

    def __init__(self, in_channels, out_channels, time_embed_dim, y_embed_dim, num_layers, dropout_rate, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        unit_count = num_layers + 1

        def first_or_out(position):
            # Only the very first unit consumes the block's input width.
            return in_channels if position == 0 else out_channels

        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=first_or_out(pos),
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            )
            for pos in range(unit_count)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            )
            for _ in range(unit_count)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim)
            for _ in range(unit_count)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            )
            for _ in range(unit_count)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=first_or_out(pos),
                out_channels=out_channels,
                kernel_size=1
            )
            for pos in range(unit_count)
        ])

    def _residual_unit(self, index, tensor, embed_time):
        """Run residual unit `index`: conv, time-embedding add, conv, shortcut add."""
        shortcut = tensor
        tensor = self.conv1[index](tensor)
        tensor = tensor + self.time_embedding[index](embed_time)[:, :, None, None]
        tensor = self.conv2[index](tensor)
        return tensor + self.resnet[index](shortcut)

    def forward(self, x, embed_time, y=None):
        """Apply the initial residual unit, then alternate attention and residual units.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map with `in_channels` channels.
        embed_time : torch.Tensor
            Time embedding; each unit projects it and broadcasts it over the two
            spatial dimensions.
        y : torch.Tensor, optional
            Conditioning tensor passed to the attention modules. When None the
            attention modules are called with the feature map only (default: None).

        Returns
        -------
        torch.Tensor
            Feature map produced by the last residual unit.
        """
        output = self._residual_unit(0, x, embed_time)
        for i in range(self.num_layers):
            # Attention always runs in iteration 0, and in all with y_to_all.
            if self.y_to_all or i == 0:
                if y is None:
                    out_attn = self.attention[i](output)
                else:
                    out_attn = self.attention[i](output, y)
                output = output + out_attn
            output = self._residual_unit(i + 1, output, embed_time)
        return output
|
|
517
|
+
#------------------------------------------------------------------------------
|
|
518
|
+
class UpBlock(nn.Module):
    """Decoder stage of NoisePredictor: upsample, concatenate skip, residual conv pairs.

    `forward` first applies `up_sampling`, concatenates the skip connection on
    dim 1, and then runs `num_layers` layers of Conv3 -> add projected time
    embedding -> Conv3 -> add 1x1 shortcut, with an `Attention` output added after
    the first layer (or after every layer when `y_to_all` is True).

    Parameters
    ----------
    in_channels : int
        Channel count of the block input (before upsampling).
    out_channels : int
        Channel count produced by every layer of the block.
    skip_channels : int
        Declared channel count of the skip connection.
    time_embed_dim : int
        Width of the time embedding given to each `TimeEmbedding`.
    y_embed_dim : int
        Width of the conditioning embedding given to each `Attention`.
    num_layers : int
        Number of residual conv pairs.
    up_sampling_factor : int
        Factor handed to `UpSampling`.
    up_sampling : bool, optional
        If True an `UpSampling` module is used; otherwise `nn.Identity`
        (default: True).
    dropout_rate : float, optional
        Dropout probability for the Conv3 and Attention modules (default: 0.2).
    y_to_all : bool, optional
        If True attention runs after every layer; if False only after layer 0
        (default: False).

    Attributes
    ----------
    num_layers : int
        Number of residual conv pairs.
    y_to_all : bool
        Attention scope flag.
    conv1, conv2 : torch.nn.ModuleList
        First and second Conv3 of each pair.
    time_embedding : torch.nn.ModuleList
        One `TimeEmbedding` per layer, projecting to `out_channels`.
    attention : torch.nn.ModuleList
        One `Attention` per layer (4 heads, 8 norm groups).
    up_sampling : UpSampling or torch.nn.Identity
        Initial spatial enlargement, or a no-op.
    resnet : torch.nn.ModuleList
        1x1 convolutions used as residual shortcuts.

    Notes
    -----
    The first layer is built for ``in_channels // 2 + skip_channels`` input
    channels. NOTE(review): the concatenated tensor in `forward` must have exactly
    that many channels; confirm against `UpSampling` and the caller's skip widths.
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_embed_dim, y_embed_dim, num_layers,
                 up_sampling_factor, up_sampling=True, dropout_rate=0.2, y_to_all=False):
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all

        # Width the first layer expects after the skip concatenation.
        merged_width = in_channels // 2 + skip_channels
        conv_settings = dict(
            out_channels=out_channels,
            num_groups=8,
            kernel_size=3,
            norm=True,
            activation=True,
            dropout_rate=dropout_rate
        )

        entry_convs = []
        for layer_idx in range(num_layers):
            width_in = merged_width if layer_idx == 0 else out_channels
            entry_convs.append(Conv3(in_channels=width_in, **conv_settings))
        self.conv1 = nn.ModuleList(entry_convs)

        exit_convs = []
        for _ in range(num_layers):
            exit_convs.append(Conv3(in_channels=out_channels, **conv_settings))
        self.conv2 = nn.ModuleList(exit_convs)

        time_layers = []
        for _ in range(num_layers):
            time_layers.append(TimeEmbedding(output_dim=out_channels, embed_dim=time_embed_dim))
        self.time_embedding = nn.ModuleList(time_layers)

        attention_layers = []
        for _ in range(num_layers):
            attention_layers.append(
                Attention(
                    in_channels=out_channels,
                    y_embed_dim=y_embed_dim,
                    num_groups=8,
                    num_heads=4,
                    dropout_rate=dropout_rate
                )
            )
        self.attention = nn.ModuleList(attention_layers)

        if up_sampling:
            self.up_sampling = UpSampling(
                in_channels=in_channels,
                out_channels=in_channels,
                up_sampling_factor=up_sampling_factor,
                conv_block=True,
                up_sampling=True
            )
        else:
            self.up_sampling = nn.Identity()

        shortcuts = []
        for layer_idx in range(num_layers):
            width_in = merged_width if layer_idx == 0 else out_channels
            shortcuts.append(nn.Conv2d(in_channels=width_in, out_channels=out_channels, kernel_size=1))
        self.resnet = nn.ModuleList(shortcuts)

    def forward(self, x, skip_connection, embed_time, y=None):
        """Upsample, merge the skip connection, then apply the residual layers.

        Parameters
        ----------
        x : torch.Tensor
            Input feature map.
        skip_connection : torch.Tensor
            Tensor concatenated to the upsampled `x` on dim 1; all other
            dimensions must match for the concatenation to succeed.
        embed_time : torch.Tensor
            Time embedding; each layer projects it and broadcasts it over the two
            spatial dimensions.
        y : torch.Tensor, optional
            Conditioning tensor passed to the attention modules. When None the
            attention modules are called with the feature map only (default: None).

        Returns
        -------
        torch.Tensor
            Result of the last layer, or the concatenated tensor itself when
            `num_layers` is 0.
        """
        upsampled = self.up_sampling(x)
        output = torch.cat(tensors=[upsampled, skip_connection], dim=1)

        for i in range(self.num_layers):
            shortcut = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](shortcut)
            # Attention always runs on layer 0, and on every layer with y_to_all.
            if self.y_to_all or i == 0:
                if y is None:
                    out_attn = self.attention[i](output)
                else:
                    out_attn = self.attention[i](output, y)
                output = output + out_attn

        return output
|
|
672
|
+
#------------------------------------------------------------------------
|
|
673
|
+
class Conv3(nn.Module):
    """2-D convolution followed by optional GroupNorm, optional SiLU, and dropout.

    Building block used by DownBlock, MiddleBlock and UpBlock.

    Parameters
    ----------
    in_channels : int
        Channel count of the input.
    out_channels : int
        Channel count of the output.
    num_groups : int, optional
        Group count for `GroupNorm`; only used when `norm` is True (default: 8).
    kernel_size : int, optional
        Convolution kernel size; padding is ``(kernel_size - 1) // 2``, which
        keeps the spatial size for odd kernels (default: 3).
    norm : bool, optional
        If True apply `GroupNorm`, otherwise an identity (default: True).
    activation : bool, optional
        If True apply `SiLU`, otherwise an identity (default: True).
    dropout_rate : float, optional
        Dropout probability (default: 0.2).

    Attributes
    ----------
    conv : torch.nn.Conv2d
        The convolution.
    group_norm : torch.nn.GroupNorm or torch.nn.Identity
        Normalisation step.
    activation : torch.nn.SiLU or torch.nn.Identity
        Non-linearity step.
    dropout : torch.nn.Dropout
        Dropout step.
    """

    def __init__(self, in_channels, out_channels, num_groups=8, kernel_size=3, norm=True, activation=True, dropout_rate=0.2):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)

        if norm:
            self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)
        else:
            self.group_norm = nn.Identity()

        if activation:
            self.activation = nn.SiLU()
        else:
            self.activation = nn.Identity()

        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch):
        """Apply convolution, normalisation, activation and dropout in that order.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor accepted by the 2-D convolution.

        Returns
        -------
        torch.Tensor
            Tensor with `out_channels` channels.
        """
        convolved = self.conv(batch)
        normalised = self.group_norm(convolved)
        activated = self.activation(normalised)
        return self.dropout(activated)
|
|
731
|
+
#----------------------------------------------------------------
|
|
732
|
+
class TimeEmbedding(nn.Module):
    """Projects a time embedding to a target channel count.

    Applies SiLU followed by a linear layer so the result can be matched to the
    channel dimension of a convolutional feature map.

    Parameters
    ----------
    output_dim : int
        Size of the projected embedding (the linear layer's output features).
    embed_dim : int
        Size of the incoming time embedding (the linear layer's input features).

    Attributes
    ----------
    embedding : torch.nn.Sequential
        ``SiLU`` followed by ``Linear(embed_dim, output_dim)``.
    """
    def __init__(self, output_dim, embed_dim):
        super().__init__()
        projection = nn.Linear(in_features=embed_dim, out_features=output_dim)
        # Activation first, projection second (indices 0 and 1 of the Sequential).
        self.embedding = nn.Sequential(nn.SiLU(), projection)

    def forward(self, batch):
        """Apply SiLU and the linear projection to the time embeddings.

        Parameters
        ----------
        batch : torch.Tensor
            Time embeddings, shape (batch_size, embed_dim).

        Returns
        -------
        torch.Tensor
            Projected embeddings, shape (batch_size, output_dim).
        """
        projected = self.embedding(batch)
        return projected
|
|
769
|
+
#----------------------------------------------------------------
|
|
770
|
+
class GetEmbeddedTime(nn.Module):
    """Sinusoidal (transformer-style) embedding of diffusion time steps.

    Each time step ``t`` is mapped to ``sin(t / f_i)`` and ``cos(t / f_i)`` with
    ``f_i = 10000 ** (2 * i / embed_dim)`` for ``i`` in ``[0, embed_dim // 2)``.
    The module holds no learnable parameters.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of the time embeddings (must be even).

    Attributes
    ----------
    embed_dim : int
        Time embedding dimension.

    Raises
    ------
    AssertionError
        If `embed_dim` is not divisible by 2.
    """
    def __init__(self, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps):
        """Compute the sinusoidal embedding for a batch of time steps.

        Parameters
        ----------
        time_steps : torch.Tensor
            Time steps, shape (batch_size,).

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch_size, embed_dim): the first half of the
            last axis holds the sine terms, the second half the cosine terms.
        """
        half_dim = self.embed_dim // 2
        # Frequency indices are created on the same device as the time steps.
        freq_index = torch.arange(half_dim, dtype=torch.float32, device=time_steps.device)
        denominators = 10000 ** (2 * freq_index / self.embed_dim)
        # Broadcast (batch, 1) against (half_dim,) -> (batch, half_dim).
        angles = time_steps[:, None] / denominators
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
|
|
814
|
+
#----------------------------------------------------------------
|
|
815
|
+
class Attention(nn.Module):
    """Attention module for NoisePredictor, supporting text conditioning or self-attention.

    The feature map is flattened to a sequence of ``height * width`` tokens.
    When text embeddings ``y`` are given they are projected to ``in_channels``
    and used as keys/values (cross-attention); otherwise the tokens attend to
    themselves (self-attention). The attention output is reshaped back to a
    feature map, group-normalized, and passed through dropout. No residual
    connection is added here; callers add the result to their own features.

    Parameters
    ----------
    in_channels : int
        Number of input channels (embedding dimension for attention).
    y_embed_dim : int, optional
        Dimensionality of text embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    dropout_rate : float, optional
        Dropout rate for attention and output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Input channel dimension.
    y_embed_dim : int
        Text embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate.
    attention : torch.nn.MultiheadAttention
        Multi-head attention with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention output.
    dropout : torch.nn.Dropout
        Dropout layer for output.
    y_projection : torch.nn.Linear
        Projection for text embeddings to match `in_channels`.

    Raises
    ------
    AssertionError
        If input channels do not match `in_channels`.
    ValueError
        If text embeddings (`y`) have incorrect dimensions after projection.
    """
    def __init__(self, in_channels, y_embed_dim=768, num_heads=4, num_groups=8, dropout_rate=0.1):
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x, y=None):
        """Applies attention to input features with optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
            Non-contiguous tensors are accepted.
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input `x`.

        Raises
        ------
        AssertionError
            If the channel dimension of `x` differs from `in_channels`.
        ValueError
            If `y` is neither 2D nor 3D after projection, or its last
            dimension does not equal `in_channels`.
        """
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        # reshape (rather than view) so non-contiguous feature maps, e.g. a
        # transposed tensor, do not raise; it is a no-copy view when possible.
        x_reshaped = x.reshape(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() == 2:
                # A single embedding per sample becomes a length-1 key/value sequence.
                y = y.unsqueeze(1)
            elif y.dim() != 3:
                raise ValueError(
                    f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                )
            # Defensive guard: with the default Linear projection the last
            # dimension always equals in_channels.
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        # Back from (batch, tokens, channels) to (batch, channels, height, width).
        out = out.permute(0, 2, 1).reshape(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
|
|
910
|
+
#-----------------------------------------------------------------
|
|
911
|
+
class DownSampling(nn.Module):
    """Downsampling module for NoisePredictor's DownBlock.

    Combines a strided-convolution path and a max-pooling path (each optional)
    and concatenates their outputs along the channel axis. When both paths are
    enabled the channels are split so that they always sum to `out_channels`:
    the pooling path gets ``out_channels // 2`` and the convolutional path the
    remainder (identical halves for even `out_channels`).

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    down_sampling_factor : int
        Factor for spatial downsampling.
    conv_block : bool, optional
        If True, include convolutional path (default: True).
    max_pool : bool, optional
        If True, include max pooling path (default: True).

    Attributes
    ----------
    conv_block : bool
        Flag for convolutional path.
    max_pool : bool
        Flag for max pooling path.
    down_sampling_factor : int
        Downsampling factor.
    conv : torch.nn.Sequential or torch.nn.Identity
        Convolutional path or identity if `conv_block=False`.
    pool : torch.nn.Sequential or torch.nn.Identity
        Max pooling path or identity if `max_pool=False`.

    Notes
    -----
    - If both `conv_block` and `max_pool` are False, both paths are identities
      and `forward` returns its input unchanged.
    """
    def __init__(self, in_channels, out_channels, down_sampling_factor, conv_block=True, max_pool=True):
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor

        if conv_block and max_pool:
            # Split so the concatenation has exactly out_channels channels,
            # including when out_channels is odd.
            pool_channels = out_channels // 2
            conv_channels = out_channels - pool_channels
        else:
            conv_channels = pool_channels = out_channels

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=conv_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=pool_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch):
        """Downsamples input using convolutional and/or pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Downsampled tensor, shape (batch_size, out_channels,
            height/down_sampling_factor, width/down_sampling_factor) when the
            spatial size is divisible by the factor.

        Notes
        -----
        - With both paths enabled and a spatial size that is not divisible by
          the factor, the strided convolution rounds the output size up while
          max pooling rounds it down; the pooled features are then resized to
          the convolutional output size with nearest-neighbor interpolation.
        """
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        conv_output = self.conv(batch)
        pool_output = self.pool(batch)
        if conv_output.shape[2:] != pool_output.shape[2:]:
            _, _, h, w = conv_output.shape
            pool_output = torch.nn.functional.interpolate(
                pool_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, pool_output], dim=1)
|
|
978
|
+
#--------------------------------------------------------------------------
|
|
979
|
+
class UpSampling(nn.Module):
    """Upsampling module for NoisePredictor's UpBlock.

    Combines a transposed-convolution path and a nearest-neighbor upsampling
    path (each optional) and concatenates their outputs along the channel axis.
    When both paths are enabled the channels are split so that they always sum
    to `out_channels`: the upsampling path gets ``out_channels // 2`` and the
    convolutional path the remainder (identical halves for even `out_channels`).

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    up_sampling_factor : int
        Factor for spatial upsampling.
    conv_block : bool, optional
        If True, include transposed convolutional path (default: True).
    up_sampling : bool, optional
        If True, include nearest-neighbor upsampling path (default: True).

    Attributes
    ----------
    conv_block : bool
        Flag for convolutional path.
    up_sampling : bool
        Flag for upsampling path.
    up_sampling_factor : int
        Upsampling factor.
    conv : torch.nn.Sequential or torch.nn.Identity
        Transposed convolutional path or identity if `conv_block=False`.
    up_sample : torch.nn.Sequential or torch.nn.Identity
        Nearest-neighbor upsampling path or identity if `up_sampling=False`.

    Notes
    -----
    - If both `conv_block` and `up_sampling` are False, both paths are
      identities and `forward` returns its input unchanged.
    """
    def __init__(self, in_channels, out_channels, up_sampling_factor, conv_block=True, up_sampling=True):
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor

        if conv_block and up_sampling:
            # Split so the concatenation has exactly out_channels channels,
            # including when out_channels is odd.
            up_channels = out_channels // 2
            conv_channels = out_channels - up_channels
        else:
            conv_channels = up_channels = out_channels

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                # With kernel 3 and padding 1 this makes the output size
                # exactly input_size * up_sampling_factor.
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=conv_channels,
                out_channels=conv_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if conv_block else nn.Identity()

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=up_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if up_sampling else nn.Identity()

    def forward(self, batch):
        """Upsamples input using convolutional and/or upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        torch.Tensor
            Upsampled tensor, shape (batch_size, out_channels,
            height*up_sampling_factor, width*up_sampling_factor).

        Notes
        -----
        - Interpolation is applied if the spatial dimensions of the convolutional and
          upsampling paths differ, using nearest-neighbor mode.
        """
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            # Safety net: align the upsampled branch to the convolutional one.
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
|