TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
torchdiff/utils.py
ADDED
@@ -0,0 +1,1660 @@
"""
**Utilities for text encoding, noise prediction, and evaluation in diffusion models**

This module provides core components for building diffusion model pipelines, including
text encoding (used as the conditional model), U-Net-based noise prediction, and image quality evaluation. These
utilities support various diffusion model architectures, such as DDPM, DDIM, LDM, and
SDE, and are designed for standalone use in model training and sampling.

**Primary Components**

- **TextEncoder**: Encodes text prompts into embeddings using a pre-trained BERT model or a custom transformer.
- **NoisePredictor**: U-Net-like architecture for predicting noise in diffusion models, supporting time and text conditioning.
- **Metrics**: Computes image quality metrics (MSE, PSNR, SSIM, FID, LPIPS) for evaluating generated images.

**Notes**

- The primary components are intended to be imported directly for use in diffusion model workflows.
- Additional supporting classes and functions in this module provide internal functionality for the primary components.


---------------------------------------------------------------------------------
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_fid import fid_score
from transformers import BertModel
import os
import math
import shutil
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision.utils import save_image
from typing import Optional, Tuple, List


###==================================================================================================================###

class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes text prompts into embeddings using either a pre-trained BERT model or a
    custom transformer architecture. Used as the `conditional_model` in diffusion models
    (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs for noise prediction.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, uses a pre-trained BERT model; otherwise, builds a custom transformer
        (default: True).
    model_name : str, optional
        Name of the pre-trained model to load (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Size of the vocabulary for the custom transformer’s embedding layer
        (default: 30522).
    num_layers : int, optional
        Number of transformer encoder layers for the custom transformer (default: 6).
    input_dimension : int, optional
        Input embedding dimension for the custom transformer (default: 768).
    output_dimension : int, optional
        Output embedding dimension for both pre-trained and custom models
        (default: 768).
    num_heads : int, optional
        Number of attention heads in the custom transformer (default: 8).
    context_length : int, optional
        Maximum sequence length for text prompts (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        If True, includes bias in query, key, and value projections for the custom
        transformer (default: False).
    scaling_value : int, optional
        Scaling factor for the feedforward layer’s hidden dimension in the custom
        transformer (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom transformer (default: 1e-5).
    use_learned_pos : bool, optional
        If True, the custom transformer uses learnable positional embeddings instead of
        sinusoidal encodings (default: False).

    **Notes**

    - When `use_pretrained_model` is True, the BERT model’s parameters are frozen
      (`requires_grad = False`), and a projection layer maps outputs to
      `output_dimension`.
    - The custom transformer uses `EncoderLayer` modules with multi-head attention and
      feedforward networks, supporting variable input/output dimensions.
    - The output shape is (batch_size, context_length, output_dimension).
    """
    def __init__(
        self,
        use_pretrained_model: bool = True,
        model_name: str = "bert-base-uncased",
        vocabulary_size: int = 30522,
        num_layers: int = 6,
        input_dimension: int = 768,
        output_dimension: int = 768,
        num_heads: int = 8,
        context_length: int = 77,
        dropout_rate: float = 0.1,
        qkv_bias: bool = False,
        scaling_value: int = 4,
        epsilon: float = 1e-5,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                max_context_length=context_length,
                use_learned_pos=use_learned_pos
            )
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])
    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encodes text prompts into embeddings.

        Processes input token IDs and an optional attention mask to produce embeddings
        using either a pre-trained BERT model or a custom transformer.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Encoded embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - For pre-trained BERT, the `last_hidden_state` is projected to
          `output_dimension` and this layer is the only trainable layer in the model.
        - For the custom transformer, token embeddings are processed through
          `Embedding` and `EncoderLayer` modules.
        - The attention mask should be 0 for padding tokens and 1 for valid tokens when
          using the custom transformer, or follow BERT’s convention for pre-trained
          models.
        """
        if self.use_pretrained_model:
            x = self.bert(input_ids=x, attention_mask=attention_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x

###==================================================================================================================###
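
# Minimal usage sketch for ``TextEncoder`` (illustrative only, not called anywhere in
# the package). It assumes the Hugging Face ``bert-base-uncased`` tokenizer is
# available; the helper name is hypothetical.
def _example_text_encoder_usage() -> torch.Tensor:
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoder = TextEncoder(use_pretrained_model=True)
    tokens = tokenizer(
        ["a photo of a cat", "an oil painting of a ship"],
        padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    )
    # Shape (2, 77, 768): one embedding per token position, ready to be passed as
    # the conditional input (``y``) of a noise predictor.
    return encoder(tokens["input_ids"], attention_mask=tokens["attention_mask"])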

class EncoderLayer(torch.nn.Module):
    """Single transformer encoder layer with multi-head attention and feedforward network.

    Used in the custom transformer of `TextEncoder` to process embedded text prompts.

    Parameters
    ----------
    input_dimension : int
        Input embedding dimension.
    output_dimension : int
        Output embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and feedforward layers.
    qkv_bias : bool
        If True, includes bias in query, key, and value projections.
    scaling_value : int
        Scaling factor for the feedforward layer’s hidden dimension.
    epsilon : float, optional
        Epsilon for layer normalization (default: 1e-5).

    **Notes**

    - The layer follows the standard transformer encoder architecture: attention,
      residual connection, normalization, feedforward, residual connection,
      normalization.
    - The attention mechanism uses `batch_first=True` for compatibility with
      `TextEncoder`’s input format.
    """
    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        num_heads: int,
        dropout_rate: float,
        qkv_bias: bool,
        scaling_value: int,
        epsilon: float = 1e-5
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)
    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input embeddings through attention and feedforward layers.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - The attention mask is converted to a `key_padding_mask` for
          `nn.MultiheadAttention`, masking out positions where the mask is 0.
        - Residual connections and normalization are applied after attention and
          feedforward layers.
        """
        # nn.MultiheadAttention expects True for positions to ignore, so invert the
        # 0/1 padding mask before self-attention over (query, key, value) = (x, x, x).
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        attn_output = self.output_projection(attn_output)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x

###==================================================================================================================###
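
# Shape sketch for a single ``EncoderLayer`` (illustrative only): embeddings go in and
# come out with the same (batch, seq_len, dim) layout. The helper name and the chosen
# sizes are hypothetical.
def _example_encoder_layer_shapes() -> torch.Size:
    layer = EncoderLayer(
        input_dimension=768, output_dimension=768, num_heads=8,
        dropout_rate=0.1, qkv_bias=False, scaling_value=4
    )
    x = torch.randn(2, 77, 768)
    mask = torch.ones(2, 77, dtype=torch.long)   # 1 = real token, 0 = padding
    mask[:, 60:] = 0                             # mark the tail of each prompt as padding
    return layer(x, attention_mask=mask).shape   # torch.Size([2, 77, 768])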

class FeedForward(torch.nn.Module):
    """Feedforward network for transformer encoder layers.

    Used in `EncoderLayer` to process attention outputs with a two-layer MLP and GELU
    activation.

    Parameters
    ----------
    embedding_dimension : int
        Input and output embedding dimension.
    scaling_value : int
        Scaling factor for the hidden layer’s dimension (hidden_dim =
        embedding_dimension * scaling_value).
    dropout_rate : float, optional
        Dropout rate after the hidden layer (default: 0.1).


    **Notes**

    - The hidden layer dimension is `embedding_dimension * scaling_value`, following
      standard transformer feedforward designs.
    - GELU activation is used for non-linearity.
    """
    def __init__(self, embedding_dimension: int, scaling_value: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=embedding_dimension,
                out_features=embedding_dimension * scaling_value,
                bias=True
            ),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(
                in_features=embedding_dimension * scaling_value,
                out_features=embedding_dimension,
                bias=True
            )
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input embeddings through the feedforward network.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, embedding_dimension).
        """
        return self.layers(x)

###==================================================================================================================###


class Attention(nn.Module):
    """Attention module for NoisePredictor, supporting text conditioning or self-attention.

    Applies multi-head attention to enhance features, with optional text embeddings for
    conditional generation.

    Parameters
    ----------
    in_channels : int
        Number of input channels (embedding dimension for attention).
    y_embed_dim : int, optional
        Dimensionality of text embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    dropout_rate : float, optional
        Dropout rate for attention and output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Input channel dimension.
    y_embed_dim : int
        Text embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate.
    attention : torch.nn.MultiheadAttention
        Multi-head attention with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention output.
    dropout : torch.nn.Dropout
        Dropout layer for output.
    y_projection : torch.nn.Linear
        Projection for text embeddings to match `in_channels`.

    Raises
    ------
    AssertionError
        If input channels do not match `in_channels`.
    ValueError
        If text embeddings (`y`) have incorrect dimensions after projection.
    """
    def __init__(
        self,
        in_channels: int,
        y_embed_dim: int = 768,
        num_heads: int = 4,
        num_groups: int = 8,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        """Applies attention to input features with optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input `x`.
        """
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        x_reshaped = x.view(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() != 3:
                if y.dim() == 2:
                    y = y.unsqueeze(1)
                else:
                    raise ValueError(
                        f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                    )
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).view(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out


###==================================================================================================================###
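
# Sketch of the two ``Attention`` modes (illustrative only): self-attention over the
# spatial grid when ``y`` is None, and cross-attention against projected text
# embeddings otherwise. The helper name and the chosen sizes are hypothetical.
def _example_attention_modes() -> tuple:
    attn = Attention(in_channels=64, y_embed_dim=768, num_heads=4, num_groups=8)
    features = torch.randn(2, 64, 16, 16)      # (batch, channels, H, W)
    text = torch.randn(2, 77, 768)             # (batch, seq_len, y_embed_dim)
    self_attended = attn(features)             # queries, keys, values all from features
    cross_attended = attn(features, text)      # keys/values from projected text
    return self_attended.shape, cross_attended.shape   # both (2, 64, 16, 16)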


class Embedding(nn.Module):
    """Token and positional embedding layer for transformer inputs.

    Used in `TextEncoder`’s transformer to embed token IDs and add positional encodings.

    Parameters
    ----------
    vocabulary_size : int
        Size of the vocabulary for token embeddings.
    embedding_dimension : int, optional
        Dimension of token and positional embeddings (default: 768).
    max_context_length : int, optional
        Maximum sequence length for precomputing positional encodings (default: 77).
    use_learned_pos : bool, optional
        If True, uses learnable positional embeddings instead of sinusoidal encodings
        (default: False).

    **Notes**

    - Supports both sinusoidal (fixed) and learned positional embeddings, selectable via
      `use_learned_pos`.
    - Sinusoidal encodings follow the transformer architecture, computed on-the-fly for
      memory efficiency and cached for sequences up to `max_context_length`.
    - Learned positional embeddings are initialized as a learnable parameter for flexibility.
    - Optimized for device-agnostic operation, ensuring seamless CPU/GPU transitions.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """
    def __init__(
        self,
        vocabulary_size: int,
        embedding_dimension: int = 768,
        max_context_length: int = 77,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.embedding_dimension = embedding_dimension
        self.max_context_length = max_context_length
        self.use_learned_pos = use_learned_pos

        # Token embedding layer
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )

        if use_learned_pos:
            # Learnable positional embeddings
            self.positional_embedding = nn.Parameter(
                torch.randn(1, max_context_length, embedding_dimension) / math.sqrt(embedding_dimension)
            )
        else:
            # Register buffer for sinusoidal encodings
            self.register_buffer(
                "positional_encoding_cache",
                torch.empty(1, 0, embedding_dimension, dtype=torch.float32)
            )

    def _generate_positional_encoding(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generates sinusoidal positional encodings for transformer inputs.

        Computes positional encodings using sine and cosine functions.

        Parameters
        ----------
        seq_len : int
            Length of the sequence for which to generate positional encodings.
        device : torch.device
            Device on which to create the positional encodings.

        Returns
        -------
        torch.Tensor
            Positional encodings, shape (1, seq_len, embedding_dimension), where
            even-indexed dimensions use sine and odd-indexed dimensions use cosine.

        **Notes**

        - Uses the formula: for position `pos` and dimension `i`,
          `PE(pos, 2i) = sin(pos / 10000^(2i/d))` and
          `PE(pos, 2i+1) = cos(pos / 10000^(2i/d))`, where `d` is `embedding_dimension`.
        - Fully vectorized for efficiency and supports any sequence length.
        """
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dimension, 2, dtype=torch.float32, device=device) *
            (-math.log(10000.0) / self.embedding_dimension)
        )
        pos_enc = torch.zeros((1, seq_len, self.embedding_dimension), dtype=torch.float32, device=device)
        pos_enc[:, :, 0::2] = torch.sin(position * div_term)
        # Odd embedding dimensions have one fewer cosine slot than sine slot.
        pos_enc[:, :, 1::2] = torch.cos(position * (div_term[:-1] if self.embedding_dimension % 2 else div_term))
        return pos_enc

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embeds token IDs and adds positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embedded tokens with positional encodings, shape
            (batch_size, seq_len, embedding_dimension).

        **Notes**

        - Automatically handles sequences longer than `max_context_length` by generating
          positional encodings on-the-fly.
        - For learned positional embeddings, sequences longer than `max_context_length`
          will raise an error unless truncated.
        - Ensures device compatibility by generating encodings on the input’s device.
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        batch_size, seq_len = token_ids.size()
        device = token_ids.device

        # Compute token embeddings
        token_embedded = self.token_embedding(token_ids)

        # Handle positional embeddings
        if self.use_learned_pos:
            if seq_len > self.max_context_length:
                raise ValueError(
                    f"Sequence length ({seq_len}) exceeds max_context_length ({self.max_context_length}) "
                    "for learned positional embeddings."
                )
            position_encoded = self.positional_embedding[:, :seq_len, :]
        else:
            # Use cached sinusoidal encodings if available and sufficient
            if (self.positional_encoding_cache.size(1) < seq_len or
                    self.positional_encoding_cache.device != device):
                self.positional_encoding_cache = self._generate_positional_encoding(
                    max(seq_len, self.max_context_length), device
                )
            position_encoded = self.positional_encoding_cache[:, :seq_len, :]

        return token_embedded + position_encoded


###==================================================================================================================###
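
# Sketch comparing the two positional-encoding modes of ``Embedding`` (illustrative
# only). Both return (batch, seq_len, embedding_dimension); only the source of the
# positional term differs. The helper name is hypothetical.
def _example_embedding_modes() -> tuple:
    token_ids = torch.randint(0, 30522, (2, 77))
    sinusoidal = Embedding(vocabulary_size=30522, use_learned_pos=False)
    learned = Embedding(vocabulary_size=30522, use_learned_pos=True)
    return sinusoidal(token_ids).shape, learned(token_ids).shape   # both (2, 77, 768)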


class NoisePredictor(nn.Module):
    """U-Net-like architecture for noise prediction in Diffusion Models.

    Predicts noise for diffusion models (DDPM, DDIM, SDE), incorporating
    time embeddings and optional text conditioning. Used as the `noise_predictor` in
    `Train` and `Sample` from the `ldm`, `sde`, `ddpm`, `ddim` modules.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    down_channels : list of int
        List of output channels for downsampling blocks.
    mid_channels : list of int
        List of channels for middle blocks.
    up_channels : list of int
        List of output channels for upsampling blocks.
    down_sampling : list of bool
        List indicating whether to downsample in each down block.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings for conditioning.
    num_down_blocks : int
        Number of convolutional layer pairs per down block.
    num_mid_blocks : int
        Number of convolutional layer pairs per middle block.
    num_up_blocks : int
        Number of convolutional layer pairs per up block.
    dropout_rate : float, optional
        Dropout rate for convolutional and attention layers (default: 0.1).
    down_sampling_factor : int, optional
        Factor for spatial downsampling/upsampling (default: 2).
    where_y : bool, optional
        If True, text embeddings are used in attention; if False, concatenated to input
        (default: True).
    y_to_all : bool, optional
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).

    **Notes**

    - The architecture follows a U-Net structure with downsampling, bottleneck, and
      upsampling blocks, incorporating time embeddings and optional text conditioning via
      attention or concatenation.
    - Skip connections link down and up blocks, with channel adjustments for concatenation.
    - Weights are initialized with Kaiming normal (Leaky ReLU nonlinearity) for stability.
    - Input and output tensors have the same shape.
    """
    def __init__(
        self,
        in_channels: int,
        down_channels: List[int],
        mid_channels: List[int],
        up_channels: List[int],
        down_sampling: List[bool],
        time_embed_dim: int,
        y_embed_dim: int,
        num_down_blocks: int,
        num_mid_blocks: int,
        num_up_blocks: int,
        dropout_rate: float = 0.1,
        down_sampling_factor: int = 2,
        where_y: bool = True,
        y_to_all: bool = False
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i+1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels)-1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i+1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels)-1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self) -> None:
        """Initializes model weights for training stability.

        Applies Kaiming normal initialization to convolutional and linear layers with
        Leaky ReLU nonlinearity (a=0.2), and zeros biases.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        clip_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predicts noise given input, time step, and optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        t : torch.Tensor
            Time steps, shape (batch_size,).
        y : torch.Tensor, optional
            Text embeddings for conditioning, shape (batch_size, seq_len, y_embed_dim)
            or (batch_size, y_embed_dim) (default: None).
        clip_embeddings : torch.Tensor, optional
            Used in the context of the unCLIP algorithm (default: None).

        Returns
        -------
        output (torch.Tensor) - Predicted noise, same shape as input `x`.
        """
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)

        if clip_embeddings is not None:
            #if len(clip_embeddings.shape) == 3: # [batch_size, seq_len, time_embed_dim]
            #    time_embed = time_embed.unsqueeze(1)
            time_embed = time_embed + clip_embeddings

        skip_connections = []
        for i, down in enumerate(self.down_blocks):
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for i, mid in enumerate(self.mid_blocks):
            output = mid(x=output, embed_time=time_embed, y=y)
        for i, up in enumerate(self.up_blocks):
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        output = self.conv2(output)
        return output

###==================================================================================================================###

class DownBlock(nn.Module):
    """Downsampling block for NoisePredictor’s encoder.

    Applies convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention, followed by downsampling if enabled.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    down_sampling_factor : int
        Factor for spatial downsampling.
    down_sample : bool
        If True, apply downsampling; if False, use identity (no downsampling).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        down_sampling_factor: int,
        down_sample: bool,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i==0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through convolutions, time embeddings, attention, and downsampling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor) if downsampling; otherwise, same height/width as input.
        """
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)

            if not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn

        output = self.down_sampling(output)
        return output

###==================================================================================================================###
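
# Forward-pass sketch for ``DownBlock`` with spatial downsampling disabled
# (illustrative only), so it exercises only the Conv3/TimeEmbedding/Attention path
# defined in this module. The helper name and the chosen sizes are hypothetical.
def _example_down_block_forward() -> torch.Size:
    block = DownBlock(
        in_channels=32, out_channels=64, time_embed_dim=128, y_embed_dim=768,
        num_layers=2, down_sampling_factor=2, down_sample=False,
        dropout_rate=0.1, y_to_all=False
    )
    x = torch.randn(2, 32, 16, 16)
    embed_time = torch.randn(2, 128)     # e.g. a projected sinusoidal time embedding
    y = torch.randn(2, 77, 768)          # optional text conditioning
    # With down_sample=False the spatial size is preserved: (2, 64, 16, 16).
    return block(x, embed_time=embed_time, y=y).shape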

class MiddleBlock(nn.Module):
    """Bottleneck block for NoisePredictor’s middle layers.

    Applies convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention, preserving spatial dimensions.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers+1)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers+1)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers+1)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers+1)
        ])

    def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through convolutions, time embeddings, and attention.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        output = x
        resnet_input = output
        output = self.conv1[0](output)
        output = output + self.time_embedding[0](embed_time)[:, :, None, None]
        output = self.conv2[0](output)
        output = output + self.resnet[0](resnet_input)

        for i in range(self.num_layers):
            if not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            resnet_input = output
            output = self.conv1[i + 1](output)
            output = output + self.time_embedding[i + 1](embed_time)[:, :, None, None]
            output = self.conv2[i + 1](output)
            output = output + self.resnet[i+1](resnet_input)
        return output

###==================================================================================================================###

class UpBlock(nn.Module):
    """Upsampling block for NoisePredictor’s decoder.

    Applies upsampling (if enabled), concatenates skip connections, and processes through
    convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels (before upsampling).
    out_channels : int
        Number of output channels.
    skip_channels : int
        Number of channels from skip connection.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    up_sampling_factor : int
        Factor for spatial upsampling.
    up_sampling : bool
        If True, apply upsampling; if False, use identity (no upsampling).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        skip_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        up_sampling_factor: int,
        up_sampling: bool,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        effective_in_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.up_sampling_ = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor,
            conv_block=True,
            up_sampling=True
        ) if up_sampling else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, skip_connection: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through upsampling, skip connection, convolutions, time embeddings, and attention.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        skip_connection : torch.Tensor
            Skip connection tensor, shape (batch_size, skip_channels,
            height*up_sampling_factor, width*up_sampling_factor).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor) if upsampling; otherwise, same height/width as input (after skip connection).
        """
        x = self.up_sampling_(x)
        x = torch.cat(tensors=[x, skip_connection], dim=1)
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)

            if not self.y_to_all and i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            elif self.y_to_all:
                out_attn = self.attention[i](output, y)
                output = output + out_attn

        return output

###==================================================================================================================###

class Conv3(nn.Module):
    """Convolutional layer with optional group normalization, SiLU activation, and dropout.

    Used in DownBlock, MiddleBlock, and UpBlock for feature extraction in NoisePredictor.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    kernel_size : int, optional
        Convolutional kernel size (default: 3).
    norm : bool, optional
        If True, apply group normalization (default: True).
    activation : bool, optional
        If True, apply SiLU activation (default: True).
    dropout_rate : float, optional
        Dropout rate (default: 0.2).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_groups: int = 8,
        kernel_size: int = 3,
        norm: bool = True,
        activation: bool = True,
        dropout_rate: float = 0.2
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels) if norm else nn.Identity()
        self.activation = nn.SiLU() if activation else nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Processes input through convolution, normalization, activation, and dropout.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        batch (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        batch = self.conv(batch)
        batch = self.group_norm(batch)
        batch = self.activation(batch)
        batch = self.dropout(batch)
        return batch

###==================================================================================================================###
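
# Shape sketch for ``Conv3`` (illustrative only): the padding keeps the spatial size,
# so only the channel count changes. The helper name is hypothetical.
def _example_conv3_shape() -> torch.Size:
    layer = Conv3(in_channels=3, out_channels=16, num_groups=8)
    return layer(torch.randn(4, 3, 32, 32)).shape   # (4, 16, 32, 32)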
|
|
1236
|
+
|
|
1237
|
+
class TimeEmbedding(nn.Module):
|
|
1238
|
+
"""Time embedding projection for conditioning NoisePredictor layers.
|
|
1239
|
+
|
|
1240
|
+
Projects time embeddings to match the channel dimension of convolutional outputs.
|
|
1241
|
+
|
|
1242
|
+
Parameters
|
|
1243
|
+
----------
|
|
1244
|
+
output_dim : int
|
|
1245
|
+
Output channel dimension (matches convolutional channels).
|
|
1246
|
+
embed_dim : int
|
|
1247
|
+
Input time embedding dimension.
|
|
1248
|
+
"""
|
|
1249
|
+
def __init__(self, output_dim: int, embed_dim: int) -> None:
|
|
1250
|
+
super().__init__()
|
|
1251
|
+
self.embedding = nn.Sequential(
|
|
1252
|
+
nn.SiLU(),
|
|
1253
|
+
nn.Linear(in_features=embed_dim, out_features=output_dim)
|
|
1254
|
+
)
|
|
1255
|
+
def forward(self, batch: torch.Tensor) -> torch.Tensor:
|
|
1256
|
+
"""Projects time embeddings to output dimension.
|
|
1257
|
+
|
|
1258
|
+
Parameters
|
|
1259
|
+
----------
|
|
1260
|
+
batch : torch.Tensor
|
|
1261
|
+
Time embeddings, shape (batch_size, embed_dim).
|
|
1262
|
+
|
|
1263
|
+
Returns
|
|
1264
|
+
-------
|
|
1265
|
+
torch.Tensor
|
|
1266
|
+
Projected embeddings, shape (batch_size, output_dim).
|
|
1267
|
+
"""
|
|
1268
|
+
return self.embedding(batch)
|
|
1269
|
+
|
|
1270
|
+
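
# --- Editor's illustrative sketch (not part of the original module): projects a time embedding
# --- to a channel dimension and broadcasts it over spatial positions, as the U-Net blocks do.
def _example_time_embedding() -> None:
    projector = TimeEmbedding(output_dim=128, embed_dim=256)
    embed_time = torch.randn(4, 256)                          # (batch, embed_dim)
    projected = projector(embed_time)                         # (batch, output_dim)
    feature_map = torch.randn(4, 128, 16, 16)
    conditioned = feature_map + projected[:, :, None, None]   # broadcast over H and W
    assert conditioned.shape == (4, 128, 16, 16)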

###==================================================================================================================###

class GetEmbeddedTime(nn.Module):
    """Generates sinusoidal time embeddings for NoisePredictor.

    Creates positional encodings for time steps using sine and cosine functions, following
    the transformer embedding approach.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of the time embeddings (must be even).
    """
    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps: torch.Tensor) -> torch.Tensor:
        """Generates sinusoidal embeddings for time steps.

        Parameters
        ----------
        time_steps : torch.Tensor
            Time steps, shape (batch_size,).

        Returns
        -------
        embed_time (torch.Tensor) - Sinusoidal embeddings, shape (batch_size, embed_dim).
        """
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time
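
# --- Editor's illustrative sketch (not part of the original module): the sinusoidal embedding maps
# --- a batch of scalar time steps to shape (batch, embed_dim), first half sin, second half cos.
def _example_get_embedded_time() -> None:
    embedder = GetEmbeddedTime(embed_dim=256)
    time_steps = torch.randint(0, 1000, (4,)).float()   # (batch,)
    embed_time = embedder(time_steps)                    # (batch, embed_dim)
    assert embed_time.shape == (4, 256)
    assert embed_time.abs().max() <= 1.0                 # sin/cos values always lie in [-1, 1]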

###==================================================================================================================###


class DownSampling(nn.Module):
    """Downsampling module for NoisePredictor’s DownBlock.

    Combines convolutional downsampling and max pooling (if enabled), concatenating
    outputs to preserve feature information.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    down_sampling_factor : int
        Factor for spatial downsampling.
    conv_block : bool, optional
        If True, include convolutional path (default: True).
    max_pool : bool, optional
        If True, include max pooling path (default: True).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        down_sampling_factor: int,
        conv_block: bool = True,
        max_pool: bool = True
    ) -> None:
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Downsamples input using convolutional and/or pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        batch (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
        """
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
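
# --- Editor's illustrative sketch (not part of the original module): with both paths enabled,
# --- each contributes out_channels // 2 channels and the spatial size shrinks by the factor.
def _example_down_sampling() -> None:
    down = DownSampling(in_channels=64, out_channels=128, down_sampling_factor=2)
    features = torch.randn(4, 64, 32, 32)
    downsampled = down(features)                  # conv path and max-pool path concatenated
    assert downsampled.shape == (4, 128, 16, 16)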

###==================================================================================================================###

class UpSampling(nn.Module):
    """Upsampling module for NoisePredictor’s UpBlock.

    Combines transposed convolution and nearest-neighbor upsampling (if enabled),
    concatenating outputs to preserve feature information, with interpolation to align
    spatial dimensions if needed.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    up_sampling_factor : int
        Factor for spatial upsampling.
    conv_block : bool, optional
        If True, include transposed convolutional path (default: True).
    up_sampling : bool, optional
        If True, include nearest-neighbor upsampling path (default: True).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        up_sampling_factor: int,
        conv_block: bool = True,
        up_sampling: bool = True
    ) -> None:
        super().__init__()
        self.conv_block = conv_block
        self.up_sampling = up_sampling
        self.up_sampling_factor = up_sampling_factor
        half_out_channels = out_channels // 2
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=3,
                stride=up_sampling_factor,
                padding=1,
                output_padding=up_sampling_factor - 1
            ),
            nn.Conv2d(
                in_channels=half_out_channels if up_sampling else out_channels,
                out_channels=half_out_channels if up_sampling else out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if conv_block else nn.Identity()

        self.up_sample = nn.Sequential(
            nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
            nn.Conv2d(in_channels=in_channels, out_channels=half_out_channels if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if up_sampling else nn.Identity()

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Upsamples input using convolutional and/or upsampling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        batch (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor).

        **Notes**

        - Interpolation is applied if the spatial dimensions of the convolutional and
          upsampling paths differ, using nearest-neighbor mode.
        """
        if not self.conv_block:
            return self.up_sample(batch)
        if not self.up_sampling:
            return self.conv(batch)
        conv_output = self.conv(batch)
        up_sample_output = self.up_sample(batch)
        if conv_output.shape[2:] != up_sample_output.shape[2:]:
            _, _, h, w = conv_output.shape
            up_sample_output = torch.nn.functional.interpolate(
                up_sample_output,
                size=(h, w),
                mode='nearest'
            )
        return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
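
# --- Editor's illustrative sketch (not part of the original module): with both paths enabled,
# --- the transposed-convolution and nearest-neighbor outputs are concatenated channel-wise.
def _example_up_sampling() -> None:
    up = UpSampling(in_channels=128, out_channels=64, up_sampling_factor=2)
    features = torch.randn(4, 128, 16, 16)
    upsampled = up(features)                      # spatial size doubled by the factor
    assert upsampled.shape == (4, 64, 32, 32)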

###==================================================================================================================###

class Metrics:
    """Computes image quality metrics for evaluating diffusion models.

    Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
    Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
    Image Patch Similarity (LPIPS) for comparing generated and ground truth images.

    Parameters
    ----------
    device : str, optional
        Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
    fid : bool, optional
        If True, compute FID score (default: True).
    metrics : bool, optional
        If True, compute MSE, PSNR, and SSIM (default: False).
    lpips_ : bool, optional
        If True, compute LPIPS using a VGG backbone (default: False).
    """

    def __init__(
        self,
        device: str = "cuda",
        fid: bool = True,
        metrics: bool = False,
        lpips_: bool = False
    ) -> None:
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_
        self.lpips_model = LearnedPerceptualImagePatchSimilarity(
            net_type='vgg',
            normalize=True  # This handles the [0, 1] -> [-1, 1] conversion
        ).to(device) if self.lpips else None
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"

    def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
        """Computes the Fréchet Inception Distance (FID) between real and generated images.

        Saves images to temporary directories and uses Inception V3 to compute FID,
        cleaning up the directories afterward.

        Parameters
        ----------
        real_images : torch.Tensor
            Real images, shape (batch_size, channels, height, width), in [-1, 1].
        fake_images : torch.Tensor
            Generated images, same shape, in [-1, 1].

        Returns
        -------
        fid (float) - FID score, or `float('inf')` if computation fails.

        **Notes**

        - Images are normalized to [0, 1] and saved as PNG files for FID computation.
        - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
        """
        if real_images.shape != fake_images.shape:
            raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")

        real_images = (real_images + 1) / 2
        fake_images = (fake_images + 1) / 2
        real_images = real_images.clamp(0, 1).cpu()
        fake_images = fake_images.clamp(0, 1).cpu()

        os.makedirs(self.temp_dir_real, exist_ok=True)
        os.makedirs(self.temp_dir_fake, exist_ok=True)

        try:
            for i, (real, fake) in enumerate(zip(real_images, fake_images)):
                save_image(real, f"{self.temp_dir_real}/{i}.png")
                save_image(fake, f"{self.temp_dir_fake}/{i}.png")

            fid = fid_score.calculate_fid_given_paths(
                paths=[self.temp_dir_real, self.temp_dir_fake],
                batch_size=50,
                device=self.device,
                dims=2048
            )
        except Exception as e:
            print(f"Error computing FID: {e}")
            fid = float('inf')
        finally:
            shutil.rmtree(self.temp_dir_real, ignore_errors=True)
            shutil.rmtree(self.temp_dir_fake, ignore_errors=True)

        return fid

    def compute_metrics(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float]:
        """Computes MSE, PSNR, and SSIM for evaluating image quality.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width).
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        mse : float
            Mean squared error.
        psnr : float
            Peak signal-to-noise ratio.
        ssim : float
            Structural similarity index (mean over batch).
        """
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        mse = F.mse_loss(x_hat, x)
        psnr = -10 * torch.log10(mse)
        c1, c2 = (0.01 * 2) ** 2, (0.03 * 2) ** 2  # SSIM constants adjusted for the [-1, 1] range
        eps = 1e-8
        mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
        mu_xy = mu_x * mu_y
        sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
        sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
        sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
        ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
            (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
        )

        return mse.item(), psnr.item(), ssim.mean().item()

    def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
        """Computes LPIPS using a pre-trained VGG network.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        lpips (float) - Mean LPIPS score over the batch.
        """
        if self.lpips_model is None:
            raise RuntimeError("LPIPS model not initialized; set lpips_=True in __init__")
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        # Normalize inputs to the [0, 1] range expected by the LPIPS model
        x = (x + 1) / 2  # Convert from [-1, 1] to [0, 1]
        x_hat = (x_hat + 1) / 2
        x = x.clamp(0, 1)  # Ensure values are in [0, 1]
        x_hat = x_hat.clamp(0, 1)

        x = x.to(self.device)
        x_hat = x_hat.to(self.device)

        # Convert grayscale to RGB if needed
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)  # Repeat the grayscale channel three times
        if x_hat.shape[1] == 1:
            x_hat = x_hat.repeat(1, 3, 1, 1)

        return self.lpips_model(x, x_hat).mean().item()

    def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float, float, float]:
        """Computes the configured metrics for ground truth and generated images.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        fid : float
            FID score, or `float('inf')` if not computed.
        mse : float or None
            Mean squared error, or None if not computed.
        psnr : float or None
            Peak signal-to-noise ratio, or None if not computed.
        ssim : float or None
            Mean SSIM, or None if not computed.
        lpips_score : float or None
            Mean LPIPS score, or None if not computed.
        """
        fid = float('inf')
        mse, psnr, ssim = None, None, None
        lpips_score = None

        if self.metrics:
            mse, psnr, ssim = self.compute_metrics(x, x_hat)
        if self.fid:
            fid = self.compute_fid(x, x_hat)
        if self.lpips:
            lpips_score = self.compute_lpips(x, x_hat)

        return fid, mse, psnr, ssim, lpips_score
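
# --- Editor's illustrative sketch (not part of the original module): computes MSE/PSNR/SSIM on
# --- CPU with FID and LPIPS disabled, so no Inception or VGG weights are needed for the demo.
def _example_metrics() -> None:
    evaluator = Metrics(device="cpu", fid=False, metrics=True, lpips_=False)
    x = torch.rand(2, 3, 64, 64) * 2 - 1                    # ground truth images in [-1, 1]
    x_hat = (x + 0.05 * torch.randn_like(x)).clamp(-1, 1)   # slightly perturbed reconstruction
    fid, mse, psnr, ssim, lpips_score = evaluator.forward(x, x_hat)
    assert fid == float('inf') and lpips_score is None      # disabled metrics keep their defaults
    print(f"MSE={mse:.4f}, PSNR={psnr:.2f} dB, SSIM={ssim:.4f}")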