TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/text_encoder.py ADDED
@@ -0,0 +1,429 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from transformers import BertModel, DistilBertModel
5
+
6
+
7
+
8
+
9
class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes tokenized prompts into embeddings using either a frozen pre-trained BERT
    model followed by a trainable linear projection, or a custom transformer built
    from ``Embedding`` and ``EncoderLayer`` modules. Used as the ``conditional_model``
    in diffusion models (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs
    for noise prediction.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, loads a pre-trained BERT model; otherwise builds a custom
        transformer (default: True).
    model_name : str, optional
        Name passed to ``BertModel.from_pretrained`` (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Vocabulary size of the custom transformer's token embedding (default: 30522).
    num_layers : int, optional
        Number of ``EncoderLayer`` modules in the custom transformer (default: 6).
    input_dimension : int, optional
        Width of the custom transformer's token embeddings and of every layer's
        residual stream (default: 768). Must be divisible by ``num_heads``.
    output_dimension : int, optional
        Width of the returned embeddings for both back-ends (default: 768).
    num_heads : int, optional
        Number of attention heads per custom layer (default: 8).
    context_length : int, optional
        Number of positional encodings pre-computed by the custom transformer's
        ``Embedding`` (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        If True, includes bias in the custom layers' attention projections
        (default: False).
    scaling_value : int, optional
        Multiplier for the feedforward hidden width in the custom layers (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom layers (default: 1e-5).

    Attributes
    ----------
    use_pretrained_model : bool
        Whether the pre-trained BERT back-end is used.
    bert : transformers.BertModel or None
        Frozen pre-trained BERT model, if ``use_pretrained_model`` is True.
    projection : torch.nn.Linear or None
        Maps BERT's hidden size to ``output_dimension``, if ``use_pretrained_model``
        is True.
    embedding : Embedding or None
        Token + positional embedding of the custom transformer, if
        ``use_pretrained_model`` is False.
    layers : torch.nn.ModuleList or None
        ``EncoderLayer`` stack of the custom transformer, if ``use_pretrained_model``
        is False.
    output_projection : torch.nn.Linear or torch.nn.Identity
        Custom transformer only: maps the last layer's ``input_dimension`` output to
        ``output_dimension``; identity when the two are equal.

    Notes
    -----
    - With the pre-trained back-end, BERT's parameters are frozen
      (``requires_grad = False``); only ``projection`` is trainable.
    - All custom layers run at ``input_dimension`` and the width changes once, in
      ``output_projection``. If each layer changed the width itself, the second
      layer would receive ``output_dimension`` tensors while being built for
      ``input_dimension``.
    - The output shape is (batch_size, seq_len, output_dimension).
    """

    def __init__(
        self,
        use_pretrained_model=True,
        model_name="bert-base-uncased",
        vocabulary_size=30522,
        num_layers=6,
        input_dimension=768,
        output_dimension=768,
        num_heads=8,
        context_length=77,
        dropout_rate=0.1,
        qkv_bias=False,
        scaling_value=4,
        epsilon=1e-5
    ):
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # Freeze BERT so only the projection head receives gradients.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
            self.embedding = None
            self.layers = None
        else:
            self.bert = None
            self.projection = None
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                context_length=context_length
            )
            # Keep every layer at input_dimension so the stack composes for any
            # input/output widths; project once at the end.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=input_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])
            # Identity (no parameters) when widths match, so state_dict keys are
            # unchanged for the default configuration.
            self.output_projection = (
                nn.Linear(input_dimension, output_dimension)
                if input_dimension != output_dimension
                else nn.Identity()
            )

    def forward(self, x, attention_mask=None):
        """Encodes text prompts into embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore and 1 indicates valid tokens (default: None).

        Returns
        -------
        torch.Tensor
            Encoded embeddings, shape (batch_size, seq_len, output_dimension).

        Notes
        -----
        - Pre-trained back-end: BERT's ``last_hidden_state`` is projected to
          ``output_dimension``; the mask is forwarded to BERT unchanged.
        - Custom back-end: tokens go through ``Embedding``, each ``EncoderLayer``
          (receiving the same mask), then ``output_projection``.
        """
        if self.use_pretrained_model:
            hidden = self.bert(input_ids=x, attention_mask=attention_mask).last_hidden_state
            return self.projection(hidden)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return self.output_projection(x)
152
+ #-----------------------------------------------------------------------------------------------------------------------
153
class EncoderLayer(torch.nn.Module):
    """Single post-norm transformer encoder layer (self-attention + feedforward).

    Used in the custom transformer of ``TextEncoder`` to process embedded text prompts.

    Parameters
    ----------
    input_dimension : int
        Width of the incoming embeddings and of the attention/feedforward residual
        stream. Must be divisible by ``num_heads``.
    output_dimension : int
        Width of the returned embeddings.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and feedforward layers.
    qkv_bias : bool
        If True, includes bias in the attention's input/output projections.
    scaling_value : int
        Multiplier for the feedforward hidden width.
    epsilon : float, optional
        Epsilon for layer normalization (default: 1e-5).

    Attributes
    ----------
    attention : torch.nn.MultiheadAttention
        Multi-head self-attention (``batch_first=True``).
    output_projection : torch.nn.Linear or torch.nn.Identity
        Maps the layer's ``input_dimension`` result to ``output_dimension`` right
        before the final normalization; identity when the two are equal.
    norm1 : torch.nn.LayerNorm
        Normalization after the attention residual (``input_dimension``).
    dropout1 : torch.nn.Dropout
        Dropout on the attention output.
    feedforward : FeedForward
        Position-wise feedforward network at ``input_dimension``.
    norm2 : torch.nn.LayerNorm
        Final normalization (``output_dimension``).
    dropout2 : torch.nn.Dropout
        Dropout on the feedforward output.

    Notes
    -----
    - Both residual additions happen at ``input_dimension``; only then is the result
      projected to ``output_dimension``. This lets the layer change width without
      adding tensors of different sizes. For equal widths the projection is an
      identity and the computation is attention -> residual -> norm -> feedforward
      -> residual -> norm.
    """

    def __init__(
        self,
        input_dimension,
        output_dimension,
        num_heads,
        dropout_rate,
        qkv_bias,
        scaling_value,
        epsilon=1e-5
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, attention_mask=None):
        """Processes input embeddings through attention and feedforward layers.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 (or False) marks
            padding tokens to ignore and 1 (or True) marks valid tokens
            (default: None).

        Returns
        -------
        torch.Tensor
            Processed embeddings, shape (batch_size, seq_len, output_dimension).

        Notes
        -----
        ``nn.MultiheadAttention`` ignores keys whose ``key_padding_mask`` entry is
        True. The documented 0-means-padding mask is the inverse of that, so it is
        converted with ``attention_mask == 0`` before being passed in.
        """
        key_padding_mask = None if attention_mask is None else attention_mask == 0
        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.output_projection(x + self.dropout2(ff_output))
        return self.norm2(x)
258
+ #-----------------------------------------------------------------------------------------------------------------------
259
class FeedForward(torch.nn.Module):
    """Two-layer position-wise MLP used inside ``EncoderLayer``.

    Expands each token vector to ``embedding_dimension * scaling_value`` features,
    applies GELU and dropout, then projects back to ``embedding_dimension``.

    Parameters
    ----------
    embedding_dimension : int
        Width of the input and output vectors.
    scaling_value : int
        Expansion factor; the hidden width is ``embedding_dimension * scaling_value``.
    dropout_rate : float, optional
        Dropout probability applied after the GELU (default: 0.1).

    Attributes
    ----------
    layers : torch.nn.Sequential
        ``[Linear(expand), GELU, Dropout, Linear(contract)]``, in that order, so
        parameter names are ``layers.0.*`` and ``layers.3.*``.
    """

    def __init__(self, embedding_dimension, scaling_value, dropout_rate=0.1):
        super().__init__()
        hidden_width = embedding_dimension * scaling_value
        # Build expand before contract to keep parameter-initialisation order.
        expand = nn.Linear(embedding_dimension, hidden_width, bias=True)
        contract = nn.Linear(hidden_width, embedding_dimension, bias=True)
        self.layers = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout_rate), contract)

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        torch.Tensor
            Output embeddings, shape (batch_size, seq_len, embedding_dimension).
        """
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
316
+ #-----------------------------------------------------------------------------------------------------------------------
317
class Embedding(torch.nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding.

    Used in ``TextEncoder``'s custom transformer to embed token IDs and add
    positional information.

    Parameters
    ----------
    vocabulary_size : int
        Number of rows in the token embedding table.
    embedding_dimension : int, optional
        Width of the token and positional embeddings (default: 768). Odd widths
        are supported.
    context_length : int, optional
        Number of positions whose encodings are pre-computed (default: 77).

    Attributes
    ----------
    token_embedding : torch.nn.Embedding
        Token embedding layer.
    embedding_dimension : int
        Width of the embeddings.
    context_length : int
        Number of pre-computed positions.
    positional_encoding : torch.Tensor
        Registered buffer of shape (1, context_length, embedding_dimension).

    Notes
    -----
    - Sequences longer than ``context_length`` get their encodings generated on the
      fly in ``forward``.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """

    def __init__(
        self,
        vocabulary_size,
        embedding_dimension=768,
        context_length=77
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        self.embedding_dimension = embedding_dimension
        self.context_length = context_length
        self.register_buffer("positional_encoding", self._generate_positional_encoding(context_length))

    def _generate_positional_encoding(self, seq_len):
        """Build sinusoidal positional encodings.

        Parameters
        ----------
        seq_len : int
            Number of positions to encode.

        Returns
        -------
        torch.Tensor
            Float32 CPU tensor of shape (1, seq_len, embedding_dimension). Even
            columns hold ``sin(pos / 10000^(2i/d))`` and odd columns hold
            ``cos(pos / 10000^(2i/d))``, where ``d`` is ``embedding_dimension``.

        Notes
        -----
        For odd ``d`` there is one more sine column than cosine column, so the
        cosine half uses only the first ``d // 2`` frequencies.
        """
        d = self.embedding_dimension
        position = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d, 2, dtype=torch.float) * -(math.log(10000.0) / d))
        angles = position * div_term  # (seq_len, ceil(d / 2))
        pos_enc = torch.zeros((seq_len, d))
        pos_enc[:, 0::2] = torch.sin(angles)
        pos_enc[:, 1::2] = torch.cos(angles[:, : d // 2])
        return pos_enc.unsqueeze(0)

    def forward(self, token_ids):
        """Embeds token IDs and adds positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embedded tokens with positional encodings, shape (batch_size, seq_len,
            embedding_dimension), in the token embeddings' dtype.

        Raises
        ------
        AssertionError
            If ``token_ids`` is not a 2D tensor (batch_size, seq_len).
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O`` while
        # keeping the documented exception type.
        if token_ids.dim() != 2:
            raise AssertionError("Input token_ids should be of shape (batch_size, seq_len)")
        token_embedded = self.token_embedding(token_ids)
        seq_len = token_ids.size(1)
        if seq_len > self.context_length:
            position_encoded = self._generate_positional_encoding(seq_len)
        else:
            position_encoded = self.positional_encoding[:, :seq_len, :]
        # Match device and dtype so freshly generated float32 encodings do not
        # upcast half-precision embeddings.
        position_encoded = position_encoded.to(device=token_embedded.device, dtype=token_embedded.dtype)
        return token_embedded + position_encoded
@@ -0,0 +1,216 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.cuda.amp import autocast, GradScaler
4
+ from tqdm import tqdm
5
+
6
+
7
+
8
+
9
+
10
class TrainAE(nn.Module):
    """Trainer for the AutoencoderLDM variational autoencoder in Latent Diffusion Models.

    Optimizes the autoencoder with the loss returned by its ``forward`` and supports
    mixed precision on CUDA, KL warmup, gradient clipping, early stopping,
    learning-rate scheduling and optional evaluation metrics.

    The model is expected to expose ``use_vq``, ``beta`` and ``current_beta``
    attributes and to return ``(x_hat, loss, reg_loss, z)`` from ``forward``. Each
    loader batch is unpacked as ``(x, _)``; the second element is ignored.

    Parameters
    ----------
    model : AutoencoderLDM
        The autoencoder to train.
    optimizer : torch.optim.Optimizer
        Optimizer for training (e.g., Adam).
    data_loader : iterable
        Training batches (e.g. a ``torch.utils.data.DataLoader``).
    val_loader : iterable, optional
        Validation batches (default: None).
    max_epoch : int, optional
        Maximum number of training epochs (default: 100).
    metrics_ : Metrics, optional
        Object whose ``forward(x, x_hat)`` returns ``(fid, mse, psnr, ssim, lpips)``
        and whose ``fid`` / ``metrics`` / ``lpips`` flags select what is reported
        (default: None).
    device : str or torch.device, optional
        Computation device (default: 'cuda'). A falsy value selects CUDA when
        available, otherwise CPU.
    save_path : str, optional
        Path for the best-model checkpoint (default: 'vlc_model.pth').
    checkpoint : int, optional
        Stored for API compatibility; not read by this class (default: 10).
    kl_warmup_epochs : int, optional
        Epochs over which the KL weight ramps linearly from 0 to ``model.beta``;
        ``<= 0`` disables warmup (default: 10).
    patience : int, optional
        Number of non-improving evaluations before early stopping (default: 10).
    val_frequency : int, optional
        Validate every ``val_frequency`` epochs (default: 5).

    Attributes
    ----------
    device : torch.device
        Computation device.
    scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
        Stepped with the validation loss, or with the training loss when no
        ``val_loader`` is given.
    (All other constructor arguments are stored under the same names.)
    """

    def __init__(self, model, optimizer, data_loader, val_loader=None, max_epoch=100, metrics_=None,
                 device="cuda", save_path="vlc_model.pth", checkpoint=10, kl_warmup_epochs=10,
                 patience=10, val_frequency=5):
        super().__init__()
        # Normalise to torch.device: the training loop reads ``self.device.type``,
        # which a plain string such as the default "cuda" does not have.
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.data_loader = data_loader
        self.val_loader = val_loader
        self.max_epoch = max_epoch
        self.metrics_ = metrics_  # Metrics object, not moved to device
        self.save_path = save_path
        self.checkpoint = checkpoint  # NOTE(review): stored but not read by this class
        self.kl_warmup_epochs = kl_warmup_epochs
        self.patience = patience
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=5, factor=0.5)
        self.val_frequency = val_frequency

    def train(self):
        """Trains the AutoencoderLDM model.

        Each epoch sets the KL weight, runs one pass over ``data_loader`` and logs the
        mean loss. With a ``val_loader``, improvement is judged only on validation
        epochs (every ``val_frequency`` epochs) using the validation loss. Without
        one, it is judged every epoch on the training loss. The best model is saved
        to ``save_path`` and training stops after ``patience`` non-improving
        evaluations.

        NOTE: this overrides ``nn.Module.train``, so calling ``.train(False)`` /
        ``.eval()`` on the trainer itself is not supported.

        Returns
        -------
        tuple
            - train_losses: list of mean training losses per epoch.
            - best_val_loss: best validation loss (or best training loss when no
              ``val_loader`` is given).
        """
        use_amp = self.device.type == "cuda"
        scaler = GradScaler(enabled=use_amp)
        self.model.train()
        train_losses = []
        best_val_loss = float("inf")
        wait = 0

        for epoch in range(self.max_epoch):
            self.model.current_beta = self._kl_beta(epoch)
            mean_train_loss = self._train_one_epoch(scaler, use_amp)
            train_losses.append(mean_train_loss)
            line = f"Epoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}"

            if self.val_loader is not None:
                if (epoch + 1) % self.val_frequency != 0:
                    # No validation this epoch: a training loss is not comparable to
                    # best_val_loss, so skip checkpointing and patience.
                    print(line)
                    continue
                current, report = self._validate_and_format()
                line += report
            else:
                current = mean_train_loss
            self.scheduler.step(current)

            if current < best_val_loss:
                best_val_loss = current
                wait = 0
                self._save_checkpoint(epoch, best_val_loss)
                print(f"{line} | Model saved at epoch {epoch + 1}")
            else:
                print(line)
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    break

        return train_losses, best_val_loss

    def _kl_beta(self, epoch):
        """Return the regularisation weight for ``epoch``: 1.0 for VQ, else linear KL warmup."""
        if self.model.use_vq:
            return 1.0  # No warmup for VQ
        if self.kl_warmup_epochs <= 0:
            return self.model.beta
        return min(1.0, epoch / self.kl_warmup_epochs) * self.model.beta

    def _train_one_epoch(self, scaler, use_amp):
        """Run one optimisation pass over ``data_loader`` and return the mean batch loss."""
        losses = []
        for x, _ in tqdm(self.data_loader):
            x = x.to(self.device)
            self.optimizer.zero_grad()
            # torch.autocast (not torch.cuda.amp.autocast) is the API that takes
            # device_type; it is a no-op when disabled.
            with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
                x_hat, loss, reg_loss, z = self.model(x)
            scaler.scale(loss).backward()
            # Unscale first so max_norm applies to the true gradients, not the
            # loss-scaled ones.
            scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            scaler.step(self.optimizer)
            scaler.update()
            losses.append(loss.item())
        return torch.mean(torch.tensor(losses)).item()

    def _validate_and_format(self):
        """Run ``validate`` and return ``(val_loss, log_suffix)`` for the epoch log line."""
        val_loss, fid, mse, psnr, ssim, lpips_score = self.validate()
        parts = [f" | Val Loss: {val_loss:.4f}"]
        if self.metrics_ and self.metrics_.fid:
            parts.append(f" | FID: {fid:.4f}")
        if self.metrics_ and self.metrics_.metrics:
            parts.append(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}")
        if self.metrics_ and self.metrics_.lpips:
            parts.append(f" | LPIPS: {lpips_score:.4f}")
        return val_loss, "".join(parts)

    def _save_checkpoint(self, epoch, loss):
        """Write model/optimizer state, the epoch index and ``loss`` to ``save_path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, self.save_path)

    def validate(self):
        """Validates the AutoencoderLDM model and computes evaluation metrics.

        Runs the model in eval mode without gradients over ``val_loader``. The model
        is put back in train mode afterwards, even if a batch raises.

        Returns
        -------
        tuple
            - val_loss: mean validation loss (float).
            - fid: mean FID score (float, or ``float('inf')`` if not computed).
            - mse: mean MSE (float, or None if not computed).
            - psnr: mean PSNR (float, or None if not computed).
            - ssim: mean SSIM (float, or None if not computed).
            - lpips_score: mean LPIPS score (float, or None if not computed).
        """
        self.model.eval()
        val_losses = []
        fid_, mse_, psnr_, ssim_, lpips_score_ = [], [], [], [], []
        try:
            with torch.no_grad():
                for x, _ in self.val_loader:
                    x = x.to(self.device)
                    x_hat, loss, reg_loss, z = self.model(x)
                    val_losses.append(loss.item())
                    if self.metrics_ is not None:
                        fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x, x_hat)
                        if self.metrics_.fid:
                            fid_.append(fid)
                        if self.metrics_.metrics:
                            mse_.append(mse)
                            psnr_.append(psnr)
                            ssim_.append(ssim)
                        if self.metrics_.lpips:
                            lpips_score_.append(lpips_score)
        finally:
            self.model.train()

        val_loss = torch.mean(torch.tensor(val_losses)).item()
        fid_ = torch.mean(torch.tensor(fid_)).item() if fid_ else float('inf')
        mse_ = torch.mean(torch.tensor(mse_)).item() if mse_ else None
        psnr_ = torch.mean(torch.tensor(psnr_)).item() if psnr_ else None
        ssim_ = torch.mean(torch.tensor(ssim_)).item() if ssim_ else None
        lpips_score_ = torch.mean(torch.tensor(lpips_score_)).item() if lpips_score_ else None
        return val_loss, fid_, mse_, psnr_, ssim_, lpips_score_