TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/text_encoder.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import math
|
|
4
|
+
from transformers import BertModel, DistilBertModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes token IDs into embeddings using either a pre-trained BERT model or a
    custom transformer stack. Intended as the `conditional_model` of the diffusion
    models in this package.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, loads a pre-trained BERT model; otherwise builds a custom
        transformer (default: True).
    model_name : str, optional
        Name passed to `BertModel.from_pretrained` (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Vocabulary size of the custom transformer's embedding layer (default: 30522).
    num_layers : int, optional
        Number of `EncoderLayer` modules in the custom transformer (default: 6).
    input_dimension : int, optional
        Embedding width used inside the custom transformer (default: 768).
    output_dimension : int, optional
        Width of the returned embeddings for both variants (default: 768).
    num_heads : int, optional
        Number of attention heads in the custom transformer (default: 8).
    context_length : int, optional
        Length of the pre-computed positional encoding table (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        Forwarded to every `EncoderLayer` (default: False).
    scaling_value : int, optional
        Hidden-width multiplier of the feedforward sub-layer (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom transformer (default: 1e-5).

    Attributes
    ----------
    use_pretrained_model : bool
        Whether the pre-trained model is used.
    bert : transformers.BertModel
        Frozen pre-trained model; only set when `use_pretrained_model` is True.
    embedding : Embedding
        Token + positional embedding; only set when `use_pretrained_model` is False.
    layers : torch.nn.ModuleList
        Stack of `EncoderLayer` modules; only set when `use_pretrained_model` is False.
    projection : torch.nn.Linear or torch.nn.Identity
        Final map to `output_dimension`. For the custom transformer it is an
        identity when `input_dimension == output_dimension`.

    Notes
    -----
    - All BERT parameters have `requires_grad` set to False; only `projection` is
      trainable in the pre-trained variant.
    - Every custom layer works at `input_dimension`; the change of width to
      `output_dimension` happens once, after the last layer. Building each layer
      as input -> output (as before) only worked when both widths were equal,
      because the second layer would receive `output_dimension`-wide inputs.
    """

    def __init__(
            self,
            use_pretrained_model=True,
            model_name="bert-base-uncased",
            vocabulary_size=30522,
            num_layers=6,
            input_dimension=768,
            output_dimension=768,
            num_heads=8,
            context_length=77,
            dropout_rate=0.1,
            qkv_bias=False,
            scaling_value=4,
            epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # Freeze the backbone: only the projection below is trained.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Keep the whole stack at `input_dimension` so layers can be chained.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=input_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])
            # Parameter-free when the widths match, so equal-width checkpoints
            # keep exactly the same state_dict keys.
            if input_dimension != output_dimension:
                self.projection = nn.Linear(input_dimension, output_dimension)
            else:
                self.projection = nn.Identity()

    def forward(self, x, attention_mask=None):
        """Encode token IDs into embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Mask of shape (batch_size, seq_len) where 0 marks padding tokens
            (default: None). It is handed unchanged to BERT or to every
            `EncoderLayer`.

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch_size, seq_len, output_dimension).
        """
        if self.use_pretrained_model:
            x = self.bert(input_ids=x, attention_mask=attention_mask).last_hidden_state
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return self.projection(x)
|
|
152
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
153
|
+
class EncoderLayer(torch.nn.Module):
    """Single transformer encoder layer: self-attention followed by a feedforward block.

    Used by the custom transformer of `TextEncoder`.

    Parameters
    ----------
    input_dimension : int
        Width of the incoming embeddings; also the attention and feedforward width.
    output_dimension : int
        Width of the returned embeddings.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention weights and both residual branches.
    qkv_bias : bool
        Passed as `bias` to `nn.MultiheadAttention`.
    scaling_value : int
        Hidden-width multiplier of the feedforward block.
    epsilon : float, optional
        Epsilon for both layer normalizations (default: 1e-5).

    Attributes
    ----------
    attention : torch.nn.MultiheadAttention
        Self-attention module created with `batch_first=True`.
    output_projection : torch.nn.Linear or torch.nn.Identity
        Maps the layer result from `input_dimension` to `output_dimension`;
        identity when both are equal.
    norm1 : torch.nn.LayerNorm
        Normalization after the attention residual (width `input_dimension`).
    dropout1 : torch.nn.Dropout
        Dropout on the attention branch.
    feedforward : FeedForward
        Position-wise feedforward network.
    norm2 : torch.nn.LayerNorm
        Final normalization (width `output_dimension`).
    dropout2 : torch.nn.Dropout
        Dropout on the feedforward branch.

    Notes
    -----
    - Post-norm layout: attention, residual, norm, feedforward, residual,
      projection, norm.
    - The projection is applied after the second residual. With equal widths it is
      an identity, so results match the previous layout; with different widths the
      previous layout failed on the residual addition of mismatched shapes.
    """

    def __init__(
            self,
            input_dimension,
            output_dimension,
            num_heads,
            dropout_rate,
            qkv_bias,
            scaling_value,
            epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        if input_dimension != output_dimension:
            self.output_projection = nn.Linear(input_dimension, output_dimension)
        else:
            self.output_projection = nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Run self-attention and the feedforward block over the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Mask of shape (batch_size, seq_len) where 0 (or False) marks padding
            tokens and any non-zero value marks valid tokens (default: None).

        Returns
        -------
        torch.Tensor
            Processed embeddings, shape (batch_size, seq_len, output_dimension).
        """
        # nn.MultiheadAttention expects the opposite convention (True = ignore)
        # and rejects integer masks, so convert the 1/0 mask into a boolean
        # padding mask instead of passing it through unchanged.
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0

        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        # Both residual sums happen at `input_dimension`; change width only afterwards.
        x = self.output_projection(x + self.dropout2(ff_output))
        return self.norm2(x)
|
|
258
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
259
|
+
class FeedForward(torch.nn.Module):
    """Position-wise two-layer MLP used inside `EncoderLayer`.

    Expands the embeddings to a wider hidden representation, applies GELU and
    dropout, then contracts back to the original width.

    Parameters
    ----------
    embedding_dimension : int
        Width of both the input and the output.
    scaling_value : int
        Multiplier for the hidden width
        (hidden width = embedding_dimension * scaling_value).
    dropout_rate : float, optional
        Dropout probability applied after the activation (default: 0.1).

    Attributes
    ----------
    layers : torch.nn.Sequential
        Linear -> GELU -> Dropout -> Linear, in that order.
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_dimension = embedding_dimension * scaling_value
        # Built in the same order as they are applied (expansion first).
        expand = nn.Linear(embedding_dimension, hidden_dimension, bias=True)
        contract = nn.Linear(hidden_dimension, embedding_dimension, bias=True)
        self.layers = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout_rate),
            contract,
        )

    def forward(self, x):
        """Apply the feedforward network.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as `x`.
        """
        transformed = self.layers(x)
        return transformed
|
|
316
|
+
#-----------------------------------------------------------------------------------------------------------------------
|
|
317
|
+
class Embedding(torch.nn.Module):
    """Token embedding plus sinusoidal positional encoding.

    Used by the custom transformer of `TextEncoder` to turn token IDs into
    position-aware embeddings.

    Parameters
    ----------
    vocabulary_size : int
        Number of rows in the token embedding table.
    embedding_dimension : int, optional
        Width of token and positional embeddings (default: 768). Odd values are
        supported.
    context_length : int, optional
        Number of positions pre-computed at construction time (default: 77).

    Attributes
    ----------
    token_embedding : torch.nn.Embedding
        Learnable token embedding table.
    embedding_dimension : int
        Width of the embeddings.
    context_length : int
        Number of pre-computed positions.
    positional_encoding : torch.Tensor
        Registered buffer of shape (1, context_length, embedding_dimension).

    Notes
    -----
    - Sequences longer than `context_length` get a freshly computed encoding on
      every forward call instead of using the buffer.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """

    def __init__(
            self,
            vocabulary_size,
            embedding_dimension=768,
            context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Build the sinusoidal positional encoding table for `seq_len` positions.

        Parameters
        ----------
        seq_len : int
            Number of positions to encode.

        Returns
        -------
        torch.Tensor
            Tensor of shape (1, seq_len, embedding_dimension). Even columns hold
            `sin(pos / 10000^(2i/d))`, odd columns hold `cos(pos / 10000^(2i/d))`,
            with `d = embedding_dimension`.

        Notes
        -----
        - The table is created on the default device; `forward` moves it to the
          device of the token embeddings.
        """
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.embedding_dimension, 2, dtype=torch.float) *
                             -(math.log(10000.0) / self.embedding_dimension))
        angles = position * div_term
        pos_enc = torch.zeros((seq_len, self.embedding_dimension), device=position.device)
        pos_enc[:, 0::2] = torch.sin(angles)
        # With an odd width there is one more even column than odd columns, so
        # drop the surplus frequency for the cosine half (no-op for even widths).
        pos_enc[:, 1::2] = torch.cos(angles[:, :self.embedding_dimension // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embed token IDs and add positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, seq_len, embedding_dimension).

        Raises
        ------
        AssertionError
            If `token_ids` is not a 2D tensor.
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            # Longer than the cached table: compute an encoding for this length.
            position_encoded = self._generate_positional_encoding(seq_len).to(token_embedded.device)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :].to(token_embedded.device)
        return token_embedded + position_encoded
|
ldm/train_autoencoder.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torch.cuda.amp import autocast, GradScaler
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TrainAE(nn.Module):
    """Trainer for the AutoencoderLDM model used by Latent Diffusion Models.

    Runs the optimization loop for the autoencoder with optional mixed precision,
    KL warmup, gradient clipping, plateau-based learning-rate scheduling, best-model
    checkpointing and early stopping. Validation can additionally report the metrics
    produced by a `Metrics` object.

    Parameters
    ----------
    model : AutoencoderLDM
        Model to train. It must expose `use_vq` and `beta`, and its forward pass
        must return `(x_hat, loss, reg_loss, z)`.
    optimizer : torch.optim.Optimizer
        Optimizer for training.
    data_loader : torch.utils.data.DataLoader
        Training data; iterated as `(x, _)` pairs.
    val_loader : torch.utils.data.DataLoader, optional
        Validation data; iterated as `(x, _)` pairs (default: None).
    max_epoch : int, optional
        Maximum number of training epochs (default: 100).
    metrics_ : Metrics, optional
        Object whose `forward(x, x_hat)` returns `(fid, mse, psnr, ssim, lpips)` and
        which exposes the flags `fid`, `metrics` and `lpips` (default: None).
    device : str or torch.device, optional
        Computation device (default: 'cuda'). A falsy value selects CUDA when
        available and CPU otherwise.
    save_path : str, optional
        Path of the best-model checkpoint (default: 'vlc_model.pth').
    checkpoint : int, optional
        Stored on the instance; the code in this class does not read it
        (default: 10).
    kl_warmup_epochs : int, optional
        Number of epochs over which the KL weight ramps up linearly; 0 disables
        the warmup (default: 10).
    patience : int, optional
        Number of consecutive loss evaluations without improvement before early
        stopping (default: 10).
    val_frequency : int, optional
        Validate every `val_frequency` epochs (default: 5).

    Attributes
    ----------
    device : torch.device
        Computation device.
    model : AutoencoderLDM
        Model being trained, moved to `device`.
    optimizer, data_loader, val_loader, max_epoch, metrics_, save_path,
    checkpoint, kl_warmup_epochs, patience, val_frequency
        The constructor arguments, stored unchanged.
    scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
        Scheduler (patience 5, factor 0.5) stepped with the validation loss.

    Notes
    -----
    - `train` here takes no arguments and runs the training loop; it shadows
      `nn.Module.train(mode)`.
    """

    def __init__(self, model, optimizer, data_loader, val_loader=None, max_epoch=100, metrics_=None,
                 device="cuda", save_path="vlc_model.pth", checkpoint=10, kl_warmup_epochs=10,
                 patience=10, val_frequency=5):
        super().__init__()
        # Always store a torch.device: `train` reads `self.device.type`, which a
        # plain string such as the default "cuda" does not have.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.data_loader = data_loader
        self.val_loader = val_loader
        self.max_epoch = max_epoch
        self.metrics_ = metrics_  # Metrics object, not moved to device
        self.save_path = save_path
        self.checkpoint = checkpoint
        self.kl_warmup_epochs = kl_warmup_epochs
        self.patience = patience
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=5, factor=0.5)
        self.val_frequency = val_frequency

    def train(self):
        """Train the model and keep the best checkpoint.

        Each epoch sets the KL weight on the model, runs one pass over
        `data_loader` with gradient clipping (max norm 1.0), and optionally
        validates. The tracked loss is the validation loss when a validation
        loader is given (compared only on validation epochs) and the mean
        training loss otherwise. An improvement saves a checkpoint to
        `save_path`; `patience` evaluations without improvement stop training.

        Returns
        -------
        tuple
            - train_losses: list with the mean training loss of every epoch run.
            - best_val_loss: best tracked loss (validation loss, or training loss
              when no validation loader is given).
        """
        # Mixed precision and loss scaling are only enabled on CUDA; elsewhere
        # both the scaler and autocast are transparent no-ops.
        use_amp = self.device.type == "cuda"
        scaler = GradScaler(enabled=use_amp)
        self.model.train()
        train_losses = []
        best_val_loss = float("inf")
        wait = 0

        for epoch in range(self.max_epoch):
            if self.model.use_vq:
                beta = 1.0  # No warmup for VQ
            else:
                # Guard against kl_warmup_epochs == 0 (previously a ZeroDivisionError).
                if self.kl_warmup_epochs > 0:
                    warmup = min(1.0, epoch / self.kl_warmup_epochs)
                else:
                    warmup = 1.0
                beta = warmup * self.model.beta
            # NOTE(review): the flattened source does not show whether this
            # assignment was inside the `else` branch; it is applied in both
            # cases here - confirm against the model's use of `current_beta`.
            self.model.current_beta = beta

            batch_losses = []
            for x, _ in tqdm(self.data_loader):
                x = x.to(self.device)
                self.optimizer.zero_grad()
                with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
                    x_hat, loss, reg_loss, z = self.model(x)
                scaler.scale(loss).backward()
                # Clip the true gradients, not the loss-scaled ones.
                scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                scaler.step(self.optimizer)
                scaler.update()
                batch_losses.append(loss.item())

            mean_train_loss = torch.mean(torch.tensor(batch_losses)).item()
            train_losses.append(mean_train_loss)
            print(f"Epoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")

            if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
                val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()  # Newline after metrics

                current_best = val_loss
                self.scheduler.step(val_loss)
            elif self.val_loader is None:
                print()
                current_best = mean_train_loss
            else:
                # A validation loader exists but this is not a validation epoch.
                # Do not compare the training loss against the best validation
                # loss: that would corrupt checkpointing and early stopping.
                print()
                continue

            if current_best < best_val_loss:
                best_val_loss = current_best
                wait = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': best_val_loss,
                }, self.save_path)
                print(f" | Model saved at epoch {epoch + 1}")
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    break

        return train_losses, best_val_loss

    def validate(self):
        """Evaluate the model on `val_loader`.

        Switches the model to eval mode, averages the loss over all validation
        batches without gradient tracking, collects the metrics enabled on
        `metrics_`, and switches the model back to train mode.

        Returns
        -------
        tuple
            - val_loss: mean validation loss (float).
            - fid: mean FID (float), or `float('inf')` if not computed.
            - mse: mean MSE (float), or None if not computed.
            - psnr: mean PSNR (float), or None if not computed.
            - ssim: mean SSIM (float), or None if not computed.
            - lpips_score: mean LPIPS (float), or None if not computed.
        """
        self.model.eval()
        val_losses = []
        fid_, mse_, psnr_, ssim_, lpips_score_ = [], [], [], [], []

        with torch.no_grad():
            for x, _ in self.val_loader:
                x = x.to(self.device)
                x_hat, loss, reg_loss, z = self.model(x)
                val_losses.append(loss.item())
                if self.metrics_ is not None:
                    fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x, x_hat)
                    # Only keep the values whose computation is enabled.
                    if self.metrics_.fid:
                        fid_.append(fid)
                    if self.metrics_.metrics:
                        mse_.append(mse)
                        psnr_.append(psnr)
                        ssim_.append(ssim)
                    if self.metrics_.lpips:
                        lpips_score_.append(lpips_score)

        val_loss = torch.mean(torch.tensor(val_losses)).item()
        fid_ = torch.mean(torch.tensor(fid_)).item() if fid_ else float('inf')
        mse_ = torch.mean(torch.tensor(mse_)).item() if mse_ else None
        psnr_ = torch.mean(torch.tensor(psnr_)).item() if psnr_ else None
        ssim_ = torch.mean(torch.tensor(ssim_)).item() if ssim_ else None
        lpips_score_ = torch.mean(torch.tensor(lpips_score_)).item() if lpips_score_ else None

        self.model.train()
        return val_loss, fid_, mse_, psnr_, ssim_, lpips_score_
|