torchdiff-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
torchdiff/utils.py ADDED
@@ -0,0 +1,1660 @@
"""
**Utilities for text encoding, noise prediction, and evaluation in diffusion models**

This module provides core components for building diffusion model pipelines, including
text encoding (used as the conditional model), U-Net-based noise prediction, and image
quality evaluation. These utilities support various diffusion model architectures, such
as DDPM, DDIM, LDM, and SDE, and are designed for standalone use in model training and
sampling.

**Primary Components**

- **TextEncoder**: Encodes text prompts into embeddings using a pre-trained BERT model or a custom transformer.
- **NoisePredictor**: U-Net-like architecture for predicting noise in diffusion models, supporting time and text conditioning.
- **Metrics**: Computes image quality metrics (MSE, PSNR, SSIM, FID, LPIPS) for evaluating generated images.

**Notes**

- The primary components are intended to be imported directly for use in diffusion model workflows.
- Additional supporting classes and functions in this module provide internal functionality for the primary components.


---------------------------------------------------------------------------------
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_fid import fid_score
from transformers import BertModel
import os
import math
import shutil
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision.utils import save_image
from typing import Optional, Tuple, List


###==================================================================================================================###

class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes text prompts into embeddings using either a pre-trained BERT model or a
    custom transformer architecture. Used as the `conditional_model` in diffusion models
    (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs for noise prediction.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, uses a pre-trained BERT model; otherwise, builds a custom transformer
        (default: True).
    model_name : str, optional
        Name of the pre-trained model to load (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Size of the vocabulary for the custom transformer’s embedding layer
        (default: 30522).
    num_layers : int, optional
        Number of transformer encoder layers for the custom transformer (default: 6).
    input_dimension : int, optional
        Input embedding dimension for the custom transformer (default: 768).
    output_dimension : int, optional
        Output embedding dimension for both pre-trained and custom models
        (default: 768).
    num_heads : int, optional
        Number of attention heads in the custom transformer (default: 8).
    context_length : int, optional
        Maximum sequence length for text prompts (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        If True, includes bias in query, key, and value projections for the custom
        transformer (default: False).
    scaling_value : int, optional
        Scaling factor for the feedforward layer’s hidden dimension in the custom
        transformer (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom transformer (default: 1e-5).
    use_learned_pos : bool, optional
        If True, the custom transformer uses learnable positional embeddings instead
        of sinusoidal encodings (default: False).

    **Notes**

    - When `use_pretrained_model` is True, the BERT model’s parameters are frozen
      (`requires_grad = False`), and a projection layer maps outputs to
      `output_dimension`.
    - The custom transformer uses `EncoderLayer` modules with multi-head attention and
      feedforward networks, supporting variable input/output dimensions.
    - The output shape is (batch_size, context_length, output_dimension).
    """
    def __init__(
        self,
        use_pretrained_model: bool = True,
        model_name: str = "bert-base-uncased",
        vocabulary_size: int = 30522,
        num_layers: int = 6,
        input_dimension: int = 768,
        output_dimension: int = 768,
        num_heads: int = 8,
        context_length: int = 77,
        dropout_rate: float = 0.1,
        qkv_bias: bool = False,
        scaling_value: int = 4,
        epsilon: float = 1e-5,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                max_context_length=context_length,
                use_learned_pos=use_learned_pos
            )
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for _ in range(num_layers)
            ])

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encodes text prompts into embeddings.

        Processes input token IDs and an optional attention mask to produce embeddings
        using either a pre-trained BERT model or a custom transformer.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Encoded embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - For pre-trained BERT, the `last_hidden_state` is projected to
          `output_dimension`, and this projection is the only trainable layer in the
          model.
        - For the custom transformer, token embeddings are processed through
          `Embedding` and `EncoderLayer` modules.
        - The attention mask should be 0 for padding tokens and 1 for valid tokens when
          using the custom transformer, or follow BERT’s convention for pre-trained
          models.
        """
        if self.use_pretrained_model:
            x = self.bert(input_ids=x, attention_mask=attention_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x


###==================================================================================================================###
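A minimal usage sketch for the pre-trained path (the tokenizer choice and prompt are illustrative; any tokenizer whose IDs match the encoder’s vocabulary works):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
encoder = TextEncoder(use_pretrained_model=True, output_dimension=768)

tokens = tokenizer(
    ["a photo of a cat"],
    padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
# Shape (1, 77, 768): one embedding per token, ready for cross-attention conditioning.
embeddings = encoder(tokens["input_ids"], attention_mask=tokens["attention_mask"])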

class EncoderLayer(torch.nn.Module):
    """Single transformer encoder layer with multi-head attention and feedforward network.

    Used in the custom transformer of `TextEncoder` to process embedded text prompts.

    Parameters
    ----------
    input_dimension : int
        Input embedding dimension.
    output_dimension : int
        Output embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and feedforward layers.
    qkv_bias : bool
        If True, includes bias in query, key, and value projections.
    scaling_value : int
        Scaling factor for the feedforward layer’s hidden dimension.
    epsilon : float, optional
        Epsilon for layer normalization (default: 1e-5).

    **Notes**

    - The layer follows the standard transformer encoder architecture: self-attention,
      residual connection, normalization, feedforward, residual connection,
      normalization, with a final projection when `input_dimension` differs from
      `output_dimension`.
    - The attention mechanism uses `batch_first=True` for compatibility with
      `TextEncoder`’s input format.
    """
    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        num_heads: int,
        dropout_rate: float,
        qkv_bias: bool,
        scaling_value: int,
        epsilon: float = 1e-5
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input embeddings through attention and feedforward layers.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - The attention mask is inverted and passed as `key_padding_mask` to
          `nn.MultiheadAttention`, which expects True at padding positions.
        - Residual connections and normalization are applied after attention and
          feedforward layers; `output_projection` maps to `output_dimension` before
          the final normalization.
        """
        # nn.MultiheadAttention expects query, key, and value; self-attention uses the
        # input for all three. Its key_padding_mask treats True as "ignore", so the
        # 0-for-padding mask documented above is inverted here.
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.norm2(self.output_projection(x + self.dropout2(ff_output)))
        return x


###==================================================================================================================###
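A small self-attention sketch for the custom path (shapes are illustrative; the mask uses the 1-valid/0-padding convention documented above):

layer = EncoderLayer(
    input_dimension=768, output_dimension=768, num_heads=8,
    dropout_rate=0.1, qkv_bias=False, scaling_value=4
)
x = torch.randn(2, 77, 768)                  # (batch, seq_len, dim)
mask = torch.ones(2, 77, dtype=torch.long)
mask[:, 60:] = 0                             # last 17 positions are padding
out = layer(x, attention_mask=mask)          # (2, 77, 768)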

class FeedForward(torch.nn.Module):
    """Feedforward network for transformer encoder layers.

    Used in `EncoderLayer` to process attention outputs with a two-layer MLP and GELU
    activation.

    Parameters
    ----------
    embedding_dimension : int
        Input and output embedding dimension.
    scaling_value : int
        Scaling factor for the hidden layer’s dimension (hidden_dim =
        embedding_dimension * scaling_value).
    dropout_rate : float, optional
        Dropout rate after the hidden layer (default: 0.1).

    **Notes**

    - The hidden layer dimension is `embedding_dimension * scaling_value`, following
      standard transformer feedforward designs.
    - GELU activation is used for non-linearity.
    """
    def __init__(self, embedding_dimension: int, scaling_value: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=embedding_dimension,
                out_features=embedding_dimension * scaling_value,
                bias=True
            ),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(
                in_features=embedding_dimension * scaling_value,
                out_features=embedding_dimension,
                bias=True
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input embeddings through the feedforward network.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, embedding_dimension).
        """
        return self.layers(x)


###==================================================================================================================###

class Attention(nn.Module):
    """Attention module for NoisePredictor, supporting text conditioning or self-attention.

    Applies multi-head attention to enhance features, with optional text embeddings for
    conditional generation.

    Parameters
    ----------
    in_channels : int
        Number of input channels (embedding dimension for attention).
    y_embed_dim : int, optional
        Dimensionality of text embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    dropout_rate : float, optional
        Dropout rate for attention and output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Input channel dimension.
    y_embed_dim : int
        Text embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate.
    attention : torch.nn.MultiheadAttention
        Multi-head attention with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention output.
    dropout : torch.nn.Dropout
        Dropout layer for output.
    y_projection : torch.nn.Linear
        Projection for text embeddings to match `in_channels`.

    Raises
    ------
    AssertionError
        If input channels do not match `in_channels`.
    ValueError
        If text embeddings (`y`) have incorrect dimensions after projection.
    """
    def __init__(
        self,
        in_channels: int,
        y_embed_dim: int = 768,
        num_heads: int = 4,
        num_groups: int = 8,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Applies attention to input features with optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input `x`.
        """
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        x_reshaped = x.view(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            y = self.y_projection(y)
            if y.dim() != 3:
                if y.dim() == 2:
                    y = y.unsqueeze(1)
                else:
                    raise ValueError(
                        f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                    )
            if y.shape[-1] != self.in_channels:
                raise ValueError(
                    f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {y.shape[-1]}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).view(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out


###==================================================================================================================###
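A small shape sketch of the two modes, assuming channel counts divisible by the group and head counts (the tensors are illustrative):

attn = Attention(in_channels=64, y_embed_dim=768, num_heads=4)
feature_map = torch.randn(2, 64, 16, 16)
text = torch.randn(2, 77, 768)

self_out = attn(feature_map)           # self-attention over the 16*16 spatial positions
cross_out = attn(feature_map, text)    # cross-attention against the projected text tokens
assert cross_out.shape == feature_map.shape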

class Embedding(nn.Module):
    """Token and positional embedding layer for transformer inputs.

    Used in `TextEncoder`’s transformer to embed token IDs and add positional encodings.

    Parameters
    ----------
    vocabulary_size : int
        Size of the vocabulary for token embeddings.
    embedding_dimension : int, optional
        Dimension of token and positional embeddings (default: 768).
    max_context_length : int, optional
        Maximum sequence length for precomputing positional encodings (default: 77).
    use_learned_pos : bool, optional
        If True, uses learnable positional embeddings instead of sinusoidal encodings
        (default: False).

    **Notes**

    - Supports both sinusoidal (fixed) and learned positional embeddings, selectable via
      `use_learned_pos`.
    - Sinusoidal encodings follow the transformer architecture, computed on-the-fly for
      memory efficiency and cached for sequences up to `max_context_length`.
    - Learned positional embeddings are initialized as a learnable parameter for flexibility.
    - Optimized for device-agnostic operation, ensuring seamless CPU/GPU transitions.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """
    def __init__(
        self,
        vocabulary_size: int,
        embedding_dimension: int = 768,
        max_context_length: int = 77,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.embedding_dimension = embedding_dimension
        self.max_context_length = max_context_length
        self.use_learned_pos = use_learned_pos

        # Token embedding layer
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )

        if use_learned_pos:
            # Learnable positional embeddings
            self.positional_embedding = nn.Parameter(
                torch.randn(1, max_context_length, embedding_dimension) / math.sqrt(embedding_dimension)
            )
        else:
            # Register buffer for sinusoidal encodings
            self.register_buffer(
                "positional_encoding_cache",
                torch.empty(1, 0, embedding_dimension, dtype=torch.float32)
            )

    def _generate_positional_encoding(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generates sinusoidal positional encodings for transformer inputs.

        Computes positional encodings using sine and cosine functions.

        Parameters
        ----------
        seq_len : int
            Length of the sequence for which to generate positional encodings.
        device : torch.device
            Device on which to create the positional encodings.

        Returns
        -------
        torch.Tensor
            Positional encodings, shape (1, seq_len, embedding_dimension), where
            even-indexed dimensions use sine and odd-indexed dimensions use cosine.

        **Notes**

        - Uses the formula: for position `pos` and dimension `i`,
          `PE(pos, 2i) = sin(pos / 10000^(2i/d))` and
          `PE(pos, 2i+1) = cos(pos / 10000^(2i/d))`, where `d` is `embedding_dimension`.
        - Fully vectorized for efficiency and supports any sequence length.
        """
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dimension, 2, dtype=torch.float32, device=device) *
            (-math.log(10000.0) / self.embedding_dimension)
        )
        pos_enc = torch.zeros((1, seq_len, self.embedding_dimension), dtype=torch.float32, device=device)
        pos_enc[:, :, 0::2] = torch.sin(position * div_term)
        # For odd embedding dimensions there is one more sine column than cosine
        # column, so the last frequency is dropped for the cosine terms.
        if self.embedding_dimension % 2:
            pos_enc[:, :, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pos_enc[:, :, 1::2] = torch.cos(position * div_term)
        return pos_enc

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embeds token IDs and adds positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embedded tokens with positional encodings, shape
            (batch_size, seq_len, embedding_dimension).

        **Notes**

        - Automatically handles sequences longer than `max_context_length` by generating
          positional encodings on-the-fly.
        - For learned positional embeddings, sequences longer than `max_context_length`
          will raise an error unless truncated.
        - Ensures device compatibility by generating encodings on the input’s device.
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        batch_size, seq_len = token_ids.size()
        device = token_ids.device

        # Compute token embeddings
        token_embedded = self.token_embedding(token_ids)

        # Handle positional embeddings
        if self.use_learned_pos:
            if seq_len > self.max_context_length:
                raise ValueError(
                    f"Sequence length ({seq_len}) exceeds max_context_length ({self.max_context_length}) "
                    "for learned positional embeddings."
                )
            position_encoded = self.positional_embedding[:, :seq_len, :]
        else:
            # Use cached sinusoidal encodings if available and sufficient
            if (self.positional_encoding_cache.size(1) < seq_len or
                    self.positional_encoding_cache.device != device):
                self.positional_encoding_cache = self._generate_positional_encoding(
                    max(seq_len, self.max_context_length), device
                )
            position_encoded = self.positional_encoding_cache[:, :seq_len, :]

        return token_embedded + position_encoded


###==================================================================================================================###
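A brief sketch checking the sinusoidal formula above on a toy configuration (the sizes are illustrative):

embed = Embedding(vocabulary_size=100, embedding_dimension=8, max_context_length=16)
ids = torch.randint(0, 100, (2, 10))
out = embed(ids)                                          # (2, 10, 8)

pe = embed._generate_positional_encoding(seq_len=4, device=ids.device)
# PE(pos=0, 2i) = sin(0) = 0 and PE(pos=0, 2i+1) = cos(0) = 1 for every frequency.
assert torch.allclose(pe[0, 0, 0::2], torch.zeros(4))
assert torch.allclose(pe[0, 0, 1::2], torch.ones(4))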

class NoisePredictor(nn.Module):
    """U-Net-like architecture for noise prediction in Diffusion Models.

    Predicts noise for diffusion models (DDPM, DDIM, SDE), incorporating
    time embeddings and optional text conditioning. Used as the `noise_predictor` in
    `Train` and `Sample` from the `ldm`, `sde`, `ddpm`, and `ddim` modules.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    down_channels : list of int
        List of output channels for downsampling blocks.
    mid_channels : list of int
        List of channels for middle blocks.
    up_channels : list of int
        List of output channels for upsampling blocks.
    down_sampling : list of bool
        List indicating whether to downsample in each down block.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings for conditioning.
    num_down_blocks : int
        Number of convolutional layer pairs per down block.
    num_mid_blocks : int
        Number of convolutional layer pairs per middle block.
    num_up_blocks : int
        Number of convolutional layer pairs per up block.
    dropout_rate : float, optional
        Dropout rate for convolutional and attention layers (default: 0.1).
    down_sampling_factor : int, optional
        Factor for spatial downsampling/upsampling (default: 2).
    where_y : bool, optional
        If True, text embeddings are used in attention; if False, concatenated to input
        (default: True).
    y_to_all : bool, optional
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).

    **Notes**

    - The architecture follows a U-Net structure with downsampling, bottleneck, and
      upsampling blocks, incorporating time embeddings and optional text conditioning via
      attention or concatenation.
    - Skip connections link down and up blocks, with channel adjustments for concatenation.
    - Weights can be initialized with Kaiming normal (Leaky ReLU nonlinearity) for
      stability via `initialize_weights()`.
    - Input and output tensors have the same shape.
    """
    def __init__(
        self,
        in_channels: int,
        down_channels: List[int],
        mid_channels: List[int],
        up_channels: List[int],
        down_sampling: List[bool],
        time_embed_dim: int,
        y_embed_dim: int,
        num_down_blocks: int,
        num_mid_blocks: int,
        num_up_blocks: int,
        dropout_rate: float = 0.1,
        down_sampling_factor: int = 2,
        where_y: bool = True,
        y_to_all: bool = False
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.where_y = where_y
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i+1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels)-1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i+1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels)-1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self) -> None:
        """Initializes model weights for training stability.

        Applies Kaiming normal initialization to convolutional and linear layers with
        Leaky ReLU nonlinearity (a=0.2), and zeros biases.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        clip_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predicts noise given input, time step, and optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        t : torch.Tensor
            Time steps, shape (batch_size,).
        y : torch.Tensor, optional
            Text embeddings for conditioning, shape (batch_size, seq_len, y_embed_dim)
            or (batch_size, y_embed_dim) (default: None).
        clip_embeddings : torch.Tensor, optional
            CLIP embeddings added to the time embedding, used in the context of the
            unCLIP algorithm (default: None).

        Returns
        -------
        output (torch.Tensor) - Predicted noise, same shape as input `x`.
        """
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        # GetEmbeddedTime holds no parameters, so constructing it per call is harmless.
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)

        if clip_embeddings is not None:
            # if len(clip_embeddings.shape) == 3:  # [batch_size, seq_len, time_embed_dim]
            #     time_embed = time_embed.unsqueeze(1)
            time_embed = time_embed + clip_embeddings

        skip_connections = []
        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        for up in self.up_blocks:
            skip_connection = skip_connections.pop()
            output = up(x=output, skip_connection=skip_connection, embed_time=time_embed, y=y)

        output = self.conv2(output)
        return output


###==================================================================================================================###
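A configuration sketch under assumed hyperparameters (the skip-connection arithmetic lines up when channel counts double at each resolution and `up_channels` mirrors `down_channels`, as here; `time_embed_dim` must be even):

model = NoisePredictor(
    in_channels=3,
    down_channels=[32, 64, 128], mid_channels=[128, 128], up_channels=[128, 64, 32],
    down_sampling=[True, True],
    time_embed_dim=128, y_embed_dim=768,
    num_down_blocks=1, num_mid_blocks=1, num_up_blocks=1
)
x = torch.randn(2, 3, 32, 32)
t = torch.randint(0, 1000, (2,))
noise_pred = model(x, t)       # same shape as x: (2, 3, 32, 32)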

class DownBlock(nn.Module):
    """Downsampling block for NoisePredictor’s encoder.

    Applies convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention, followed by downsampling if enabled.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    down_sampling_factor : int
        Factor for spatial downsampling.
    down_sample : bool
        If True, apply downsampling; if False, use identity (no downsampling).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        down_sampling_factor: int,
        down_sample: bool,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.down_sampling = DownSampling(
            in_channels=out_channels,
            out_channels=out_channels,
            down_sampling_factor=down_sampling_factor,
            conv_block=True,
            max_pool=True
        ) if down_sample else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through convolutions, time embeddings, attention, and downsampling.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor) if downsampling; otherwise, same height/width as input.
        """
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)

            # Text-conditioned attention on every layer when y_to_all, otherwise
            # only on the first layer.
            if self.y_to_all or i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn

        output = self.down_sampling(output)
        return output


###==================================================================================================================###
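A shape-flow sketch for a single down block (unconditional, so the attention falls back to self-attention; values are illustrative):

block = DownBlock(
    in_channels=32, out_channels=64, time_embed_dim=128, y_embed_dim=768,
    num_layers=1, down_sampling_factor=2, down_sample=True,
    dropout_rate=0.1, y_to_all=False
)
x = torch.randn(2, 32, 16, 16)
t_embed = torch.randn(2, 128)
out = block(x, embed_time=t_embed)     # (2, 64, 8, 8): channels up, resolution halved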

class MiddleBlock(nn.Module):
    """Bottleneck block for NoisePredictor’s middle layers.

    Applies convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention, preserving spatial dimensions.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers+1)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers+1)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers+1)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers + 1)
        ])
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers+1)
        ])

    def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through convolutions, time embeddings, and attention.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        output = x
        resnet_input = output
        output = self.conv1[0](output)
        output = output + self.time_embedding[0](embed_time)[:, :, None, None]
        output = self.conv2[0](output)
        output = output + self.resnet[0](resnet_input)

        for i in range(self.num_layers):
            # Text-conditioned attention on every layer when y_to_all, otherwise
            # only on the first layer.
            if self.y_to_all or i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn
            resnet_input = output
            output = self.conv1[i + 1](output)
            output = output + self.time_embedding[i + 1](embed_time)[:, :, None, None]
            output = self.conv2[i + 1](output)
            output = output + self.resnet[i+1](resnet_input)
        return output


###==================================================================================================================###

class UpBlock(nn.Module):
    """Upsampling block for NoisePredictor’s decoder.

    Applies upsampling (if enabled), concatenates skip connections, and processes through
    convolutional layers with residual connections, time embeddings, and optional
    text-conditioned attention.

    Parameters
    ----------
    in_channels : int
        Number of input channels (before upsampling).
    out_channels : int
        Number of output channels.
    skip_channels : int
        Number of channels from skip connection.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings.
    num_layers : int
        Number of convolutional layer pairs (Conv3).
    up_sampling_factor : int
        Factor for spatial upsampling.
    up_sampling : bool
        If True, apply upsampling; if False, use identity (no upsampling).
    dropout_rate : float
        Dropout rate for Conv3 and attention layers.
    y_to_all : bool
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        skip_channels: int,
        time_embed_dim: int,
        y_embed_dim: int,
        num_layers: int,
        up_sampling_factor: int,
        up_sampling: bool,
        dropout_rate: float,
        y_to_all: bool
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.y_to_all = y_to_all
        effective_in_channels = in_channels // 2 + skip_channels
        self.conv1 = nn.ModuleList([
            Conv3(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for i in range(self.num_layers)
        ])
        self.conv2 = nn.ModuleList([
            Conv3(
                in_channels=out_channels,
                out_channels=out_channels,
                num_groups=8,
                kernel_size=3,
                norm=True,
                activation=True,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.time_embedding = nn.ModuleList([
            TimeEmbedding(
                output_dim=out_channels,
                embed_dim=time_embed_dim
            ) for _ in range(self.num_layers)
        ])
        self.attention = nn.ModuleList([
            Attention(
                in_channels=out_channels,
                y_embed_dim=y_embed_dim,
                num_groups=8,
                num_heads=4,
                dropout_rate=dropout_rate
            ) for _ in range(self.num_layers)
        ])
        self.up_sampling_ = UpSampling(
            in_channels=in_channels,
            out_channels=in_channels,
            up_sampling_factor=up_sampling_factor,
            conv_block=True,
            up_sampling=True
        ) if up_sampling else nn.Identity()
        self.resnet = nn.ModuleList([
            nn.Conv2d(
                in_channels=effective_in_channels if i == 0 else out_channels,
                out_channels=out_channels,
                kernel_size=1
            ) for i in range(num_layers)
        ])

    def forward(self, x: torch.Tensor, skip_connection: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input through upsampling, skip connection, convolutions, time embeddings, and attention.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        skip_connection : torch.Tensor
            Skip connection tensor, shape (batch_size, skip_channels,
            height*up_sampling_factor, width*up_sampling_factor).
        embed_time : torch.Tensor
            Time embeddings, shape (batch_size, time_embed_dim).
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).
        key_padding_mask : torch.Tensor, optional
            Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
            where `True` indicates positions to mask out (default: None).

        Returns
        -------
        output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor) if upsampling; otherwise, same height/width as input (after skip connection).
        """
        x = self.up_sampling_(x)
        x = torch.cat(tensors=[x, skip_connection], dim=1)
        output = x
        for i in range(self.num_layers):
            resnet_input = output
            output = self.conv1[i](output)
            output = output + self.time_embedding[i](embed_time)[:, :, None, None]
            output = self.conv2[i](output)
            output = output + self.resnet[i](resnet_input)

            # Text-conditioned attention on every layer when y_to_all, otherwise
            # only on the first layer.
            if self.y_to_all or i == 0:
                out_attn = self.attention[i](output, y)
                output = output + out_attn

        return output


###==================================================================================================================###

class Conv3(nn.Module):
    """Convolutional layer with optional group normalization, SiLU activation, and dropout.

    Used in DownBlock, MiddleBlock, and UpBlock for feature extraction in NoisePredictor.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    kernel_size : int, optional
        Convolutional kernel size (default: 3).
    norm : bool, optional
        If True, apply group normalization (default: True).
    activation : bool, optional
        If True, apply SiLU activation (default: True).
    dropout_rate : float, optional
        Dropout rate (default: 0.2).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_groups: int = 8,
        kernel_size: int = 3,
        norm: bool = True,
        activation: bool = True,
        dropout_rate: float = 0.2
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
        self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels) if norm else nn.Identity()
        self.activation = nn.SiLU() if activation else nn.Identity()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Processes input through convolution, normalization, activation, and dropout.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        batch (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
        """
        batch = self.conv(batch)
        batch = self.group_norm(batch)
        batch = self.activation(batch)
        batch = self.dropout(batch)
        return batch


###==================================================================================================================###

class TimeEmbedding(nn.Module):
    """Time embedding projection for conditioning NoisePredictor layers.

    Projects time embeddings to match the channel dimension of convolutional outputs.

    Parameters
    ----------
    output_dim : int
        Output channel dimension (matches convolutional channels).
    embed_dim : int
        Input time embedding dimension.
    """
    def __init__(self, output_dim: int, embed_dim: int) -> None:
        super().__init__()
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=embed_dim, out_features=output_dim)
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Projects time embeddings to output dimension.

        Parameters
        ----------
        batch : torch.Tensor
            Time embeddings, shape (batch_size, embed_dim).

        Returns
        -------
        torch.Tensor
            Projected embeddings, shape (batch_size, output_dim).
        """
        return self.embedding(batch)


###==================================================================================================================###
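A sketch of how these projections are consumed inside the blocks above; the `[:, :, None, None]` broadcast adds the projected embedding to every spatial position:

te = TimeEmbedding(output_dim=64, embed_dim=128)
embed_time = torch.randn(2, 128)
feature_map = torch.randn(2, 64, 8, 8)
conditioned = feature_map + te(embed_time)[:, :, None, None]   # (2, 64, 8, 8)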

class GetEmbeddedTime(nn.Module):
    """Generates sinusoidal time embeddings for NoisePredictor.

    Creates positional encodings for time steps using sine and cosine functions, following
    the transformer embedding approach.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of the time embeddings (must be even).
    """
    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
        self.embed_dim = embed_dim

    def forward(self, time_steps: torch.Tensor) -> torch.Tensor:
        """Generates sinusoidal embeddings for time steps.

        Parameters
        ----------
        time_steps : torch.Tensor
            Time steps, shape (batch_size,).

        Returns
        -------
        embed_time (torch.Tensor) - Sinusoidal embeddings, shape (batch_size, embed_dim).
        """
        i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
        factor = 10000 ** (2 * i / self.embed_dim)
        embed_time = time_steps[:, None] / factor
        embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
        return embed_time


###==================================================================================================================###
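A quick sketch of the resulting embedding; the sine half and cosine half are concatenated, which is why `embed_dim` must be even:

get_embed = GetEmbeddedTime(embed_dim=128)
t = torch.tensor([0, 10, 999])
emb = get_embed(t)                     # (3, 128)
# At t = 0 every sine component is 0 and every cosine component is 1.
assert torch.allclose(emb[0, :64], torch.zeros(64))
assert torch.allclose(emb[0, 64:], torch.ones(64))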

class DownSampling(nn.Module):
    """Downsampling module for NoisePredictor’s DownBlock.

    Combines convolutional downsampling and max pooling (if enabled), concatenating
    outputs to preserve feature information.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    down_sampling_factor : int
        Factor for spatial downsampling.
    conv_block : bool, optional
        If True, include convolutional path (default: True).
    max_pool : bool, optional
        If True, include max pooling path (default: True).
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        down_sampling_factor: int,
        conv_block: bool = True,
        max_pool: bool = True
    ) -> None:
        super().__init__()
        self.conv_block = conv_block
        self.max_pool = max_pool
        self.down_sampling_factor = down_sampling_factor
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
                      kernel_size=3, stride=down_sampling_factor, padding=1)
        ) if conv_block else nn.Identity()
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels//2 if conv_block else out_channels,
                      kernel_size=1, stride=1, padding=0)
        ) if max_pool else nn.Identity()

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Downsamples input using convolutional and/or pooling paths.

        Parameters
        ----------
        batch : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).

        Returns
        -------
        batch (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
        """
        if not self.conv_block:
            return self.pool(batch)
        if not self.max_pool:
            return self.conv(batch)
        return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)


###==================================================================================================================###
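A shape sketch of the two-path design; each path contributes half of `out_channels` when both are enabled:

down = DownSampling(in_channels=64, out_channels=64, down_sampling_factor=2)
x = torch.randn(2, 64, 16, 16)
out = down(x)     # (2, 64, 8, 8): strided-conv half concatenated with max-pool half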
1370
+
1371
+ class UpSampling(nn.Module):
1372
+ """Upsampling module for NoisePredictor’s UpBlock.
1373
+
1374
+ Combines transposed convolution and nearest-neighbor upsampling (if enabled),
1375
+ concatenating outputs to preserve feature information, with interpolation to align
1376
+ spatial dimensions if needed.
1377
+
1378
+ Parameters
1379
+ ----------
1380
+ in_channels : int
1381
+ Number of input channels.
1382
+ out_channels : int
1383
+ Number of output channels.
1384
+ up_sampling_factor : int
1385
+ Factor for spatial upsampling.
1386
+ conv_block : bool, optional
1387
+ If True, include transposed convolutional path (default: True).
1388
+ up_sampling : bool, optional
1389
+ If True, include nearest-neighbor upsampling path (default: True).
1390
+ """
1391
+ def __init__(
1392
+ self,
1393
+ in_channels: int,
1394
+ out_channels: int,
1395
+ up_sampling_factor: int,
1396
+ conv_block: bool = True,
1397
+ up_sampling: bool = True
1398
+ ) -> None:
1399
+ super().__init__()
1400
+ self.conv_block = conv_block
1401
+ self.up_sampling = up_sampling
1402
+ self.up_sampling_factor = up_sampling_factor
1403
+ half_out_channels = out_channels // 2
1404
+ self.conv = nn.Sequential(
1405
+ nn.ConvTranspose2d(
1406
+ in_channels=in_channels,
1407
+ out_channels=half_out_channels if up_sampling else out_channels,
1408
+ kernel_size=3,
1409
+ stride=up_sampling_factor,
1410
+ padding=1,
1411
+ output_padding=up_sampling_factor - 1
1412
+ ),
1413
+ nn.Conv2d(
1414
+ in_channels=half_out_channels if up_sampling else out_channels,
1415
+ out_channels=half_out_channels if up_sampling else out_channels,
1416
+ kernel_size=1,
1417
+ stride=1,
1418
+ padding=0
1419
+ )
1420
+ ) if conv_block else nn.Identity()
1421
+
1422
+ self.up_sample = nn.Sequential(
1423
+ nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
1424
+ nn.Conv2d(in_channels=in_channels, out_channels=half_out_channels if conv_block else out_channels,
1425
+ kernel_size=1, stride=1, padding=0)
1426
+ ) if up_sampling else nn.Identity()
1427
+
1428
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
1429
+ """Upsamples input using convolutional and/or upsampling paths.
1430
+
1431
+ Parameters
1432
+ ----------
1433
+ batch : torch.Tensor
1434
+ Input tensor, shape (batch_size, in_channels, height, width).
1435
+
1436
+ Returns
1437
+ -------
1438
+ batch (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor).
1439
+
1440
+ **Notes**
1441
+
1442
+ - Interpolation is applied if the spatial dimensions of the convolutional and
1443
+ upsampling paths differ, using nearest-neighbor mode.
1444
+ """
1445
+ if not self.conv_block:
1446
+ return self.up_sample(batch)
1447
+ if not self.up_sampling:
1448
+ return self.conv(batch)
1449
+ conv_output = self.conv(batch)
1450
+ up_sample_output = self.up_sample(batch)
1451
+ if conv_output.shape[2:] != up_sample_output.shape[2:]:
1452
+ _, _, h, w = conv_output.shape
1453
+ up_sample_output = torch.nn.functional.interpolate(
1454
+ up_sample_output,
1455
+ size=(h, w),
1456
+ mode='nearest'
1457
+ )
1458
+ return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
1459
+
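The same check for UpSampling, taken directly from the definition above: the ConvTranspose2d path (kernel_size=3, stride=2, padding=1, output_padding=1 gives exact 2x upsampling) and the nearest-neighbor path each emit out_channels // 2 maps, and their concatenation restores the full out_channels at twice the spatial resolution.

import torch

up = UpSampling(in_channels=64, out_channels=32, up_sampling_factor=2)
x = torch.randn(4, 64, 16, 16)
y = up(x)
print(y.shape)  # torch.Size([4, 32, 32, 32])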
+ ###==================================================================================================================###
+ 
+ class Metrics:
+     """Computes image quality metrics for evaluating diffusion models.
+ 
+     Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
+     Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
+     Image Patch Similarity (LPIPS) for comparing generated and ground truth images.
+ 
+     Parameters
+     ----------
+     device : str, optional
+         Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
+     fid : bool, optional
+         If True, compute FID score (default: True).
+     metrics : bool, optional
+         If True, compute MSE, PSNR, and SSIM (default: False).
+     lpips_ : bool, optional
+         If True, compute LPIPS using a VGG backbone (default: False).
+     """
+ 
+     def __init__(
+         self,
+         device: str = "cuda",
+         fid: bool = True,
+         metrics: bool = False,
+         lpips_: bool = False
+     ) -> None:
+         self.device = device
+         self.fid = fid
+         self.metrics = metrics
+         self.lpips = lpips_
+         self.lpips_model = LearnedPerceptualImagePatchSimilarity(
+             net_type='vgg',
+             normalize=True  # expects inputs in [0, 1]; handles the [0, 1] -> [-1, 1] conversion internally
+         ).to(device) if self.lpips else None
+         self.temp_dir_real = "temp_real"
+         self.temp_dir_fake = "temp_fake"
+ 
1499
+ def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
1500
+ """Computes the Fréchet Inception Distance (FID) between real and generated images.
1501
+
1502
+ Saves images to temporary directories and uses Inception V3 to compute FID,
1503
+ cleaning up directories afterward.
1504
+
1505
+ Parameters
1506
+ ----------
1507
+ real_images : torch.Tensor
1508
+ Real images, shape (batch_size, channels, height, width), in [-1, 1].
1509
+ fake_images : torch.Tensor
1510
+ Generated images, same shape, in [-1, 1].
1511
+
1512
+ Returns
1513
+ -------
1514
+ fid (float) - FID score, or `float('inf')` if computation fails.
1515
+
1516
+ **Notes**
1517
+
1518
+ - Images are normalized to [0, 1] and saved as PNG files for FID computation.
1519
+ - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
1520
+ """
1521
+ if real_images.shape != fake_images.shape:
1522
+ raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")
1523
+
1524
+ real_images = (real_images + 1) / 2
1525
+ fake_images = (fake_images + 1) / 2
1526
+ real_images = real_images.clamp(0, 1).cpu()
1527
+ fake_images = fake_images.clamp(0, 1).cpu()
1528
+
1529
+ os.makedirs(self.temp_dir_real, exist_ok=True)
1530
+ os.makedirs(self.temp_dir_fake, exist_ok=True)
1531
+
1532
+ try:
1533
+ for i, (real, fake) in enumerate(zip(real_images, fake_images)):
1534
+ save_image(real, f"{self.temp_dir_real}/{i}.png")
1535
+ save_image(fake, f"{self.temp_dir_fake}/{i}.png")
1536
+
1537
+ fid = fid_score.calculate_fid_given_paths(
1538
+ paths=[self.temp_dir_real, self.temp_dir_fake],
1539
+ batch_size=50,
1540
+ device=self.device,
1541
+ dims=2048
1542
+ )
1543
+ except Exception as e:
1544
+ print(f"Error computing FID: {e}")
1545
+ fid = float('inf')
1546
+ finally:
1547
+ shutil.rmtree(self.temp_dir_real, ignore_errors=True)
1548
+ shutil.rmtree(self.temp_dir_fake, ignore_errors=True)
1549
+
1550
+ return fid
1551
+
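For reference, `fid_score.calculate_fid_given_paths` (from the pytorch-fid package) fits a Gaussian to the Inception-V3 activations of each image folder and returns the Fréchet distance between the two Gaussians. A minimal sketch of that final step, assuming `act_real` and `act_fake` are (N, 2048) activation arrays that have already been extracted; this is illustrative, not the package's exact implementation:

import numpy as np
from scipy import linalg

def frechet_distance(act_real: np.ndarray, act_fake: np.ndarray) -> float:
    # Fit a Gaussian (mean, covariance) to each activation set.
    mu_r, sigma_r = act_real.mean(axis=0), np.cov(act_real, rowvar=False)
    mu_f, sigma_f = act_fake.mean(axis=0), np.cov(act_fake, rowvar=False)
    # FID = ||mu_r - mu_f||^2 + Tr(sigma_r + sigma_f - 2 * (sigma_r sigma_f)^(1/2))
    covmean, _ = linalg.sqrtm(sigma_r @ sigma_f, disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean))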
1552
+ def compute_metrics(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float]:
1553
+ """Computes MSE, PSNR, and SSIM for evaluating image quality.
1554
+
1555
+ Parameters
1556
+ ----------
1557
+ x : torch.Tensor
1558
+ Ground truth images, shape (batch_size, channels, height, width).
1559
+ x_hat : torch.Tensor
1560
+ Generated images, same shape as `x`.
1561
+
1562
+ Returns
1563
+ -------
1564
+ mse : float
1565
+ Mean squared error.
1566
+ psnr : float
1567
+ Peak signal-to-noise ratio.
1568
+ ssim : float
1569
+ Structural similarity index (mean over batch).
1570
+ """
1571
+ if x.shape != x_hat.shape:
1572
+ raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")
1573
+
1574
+ mse = F.mse_loss(x_hat, x)
1575
+ psnr = -10 * torch.log10(mse)
1576
+ c1, c2 = (0.01 * 2) ** 2, (0.03 * 2) ** 2 # Adjusted for [-1, 1] range
1577
+ eps = 1e-8
1578
+ mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
1579
+ mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
1580
+ mu_xy = mu_x * mu_y
1581
+ sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
1582
+ sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
1583
+ sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
1584
+ ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
1585
+ (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
1586
+ )
1587
+
1588
+ return mse.item(), psnr.item(), ssim.mean().item()
1589
+
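The SSIM map above follows the standard formula SSIM(x, y) = ((2*mu_x*mu_y + c1) * (2*sigma_xy + c2)) / ((mu_x^2 + mu_y^2 + c1) * (sigma_x^2 + sigma_y^2 + c2)) with c1 = (0.01*L)^2 and c2 = (0.03*L)^2 for dynamic range L = 2, but computes the local statistics with a 3x3 box filter rather than the 11x11 Gaussian window of Wang et al. (2004), so scores can differ slightly from other SSIM implementations. A quick sanity check:

import torch

# Identical inputs: MSE ~0, PSNR capped by the eps guard at
# 10 * log10(4 / 1e-8) ~ 86 dB, SSIM ~1.
m = Metrics(device="cpu", fid=False, metrics=True)
x = torch.rand(2, 3, 64, 64) * 2 - 1  # images in [-1, 1]
mse, psnr, ssim = m.compute_metrics(x, x.clone())
print(mse, psnr, ssim)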
1590
+ def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
1591
+ """Computes LPIPS using a pre-trained VGG network.
1592
+
1593
+ Parameters
1594
+ ----------
1595
+ x : torch.Tensor
1596
+ Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
1597
+ x_hat : torch.Tensor
1598
+ Generated images, same shape as `x`.
1599
+
1600
+ Returns
1601
+ -------
1602
+ lpips (float) - Mean LPIPS score over the batch.
1603
+ """
1604
+ if self.lpips_model is None:
1605
+ raise RuntimeError("LPIPS model not initialized; set lpips=True in __init__")
1606
+ if x.shape != x_hat.shape:
1607
+ raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")
1608
+
1609
+ # Normalize inputs to [0, 1] range
1610
+ x = (x + 1) / 2 # Convert from [-1, 1] to [0, 1]
1611
+ x_hat = (x_hat + 1) / 2
1612
+ x = x.clamp(0, 1) # Ensure values are in [0, 1]
1613
+ x_hat = x_hat.clamp(0, 1)
1614
+
1615
+ x = x.to(self.device)
1616
+ x_hat = x_hat.to(self.device)
1617
+
1618
+ # Convert grayscale to RGB if needed
1619
+ if x.shape[1] == 1:
1620
+ x = x.repeat(1, 3, 1, 1) # Repeat grayscale channel 3 times
1621
+ if x_hat.shape[1] == 1:
1622
+ x_hat = x_hat.repeat(1, 3, 1, 1)
1623
+
1624
+ return self.lpips_model(x, x_hat).mean().item()
1625
+
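A small illustrative run (assumes torchmetrics is installed and its VGG weights can be downloaded on first use); it also exercises the grayscale-to-RGB tiling above:

import torch

m = Metrics(device="cpu", fid=False, lpips_=True)
x = torch.rand(2, 1, 64, 64) * 2 - 1           # grayscale images in [-1, 1]
noisy = (x + 0.3 * torch.randn_like(x)).clamp(-1, 1)
print(m.compute_lpips(x, x))       # ~0.0 for identical inputs
print(m.compute_lpips(x, noisy))   # > 0, larger for heavier distortion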
1626
+ def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float, float, float]:
1627
+ """Computes specified metrics for ground truth and generated images.
1628
+
1629
+ Parameters
1630
+ ----------
1631
+ x : torch.Tensor
1632
+ Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
1633
+ x_hat : torch.Tensor
1634
+ Generated images, same shape as `x`.
1635
+
1636
+ Returns
1637
+ -------
1638
+ fid : float, or `float('inf')` if not computed
1639
+ Mean FID score.
1640
+ mse : float, or None if not computed
1641
+ Mean MSE
1642
+ psnr : float, or None if not computed
1643
+ Mean PSNR
1644
+ ssim : float, or None if not computed
1645
+ Mean SSIM
1646
+ lpips_score : float, or None if not computed
1647
+ Mean LPIPS score
1648
+ """
1649
+ fid = float('inf')
1650
+ mse, psnr, ssim = None, None, None
1651
+ lpips_score = None
1652
+
1653
+ if self.metrics:
1654
+ mse, psnr, ssim = self.compute_metrics(x, x_hat)
1655
+ if self.fid:
1656
+ fid = self.compute_fid(x, x_hat)
1657
+ if self.lpips:
1658
+ lpips_score = self.compute_lpips(x, x_hat)
1659
+
1660
+ return fid, mse, psnr, ssim, lpips_score
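
An end-to-end sketch of the evaluation entry point (assumes pytorch-fid and torchmetrics are installed; the random tensors stand in for real dataset images and model samples, and FID statistics are only meaningful with a reasonably large number of images per set):

import torch
from torchdiff.utils import Metrics

metrics = Metrics(device="cuda", fid=True, metrics=True, lpips_=True)
real = torch.rand(64, 3, 64, 64) * 2 - 1   # ground truth images, in [-1, 1]
fake = torch.rand(64, 3, 64, 64) * 2 - 1   # generated samples, in [-1, 1]
fid, mse, psnr, ssim, lpips_score = metrics.forward(real, fake)
print(f"FID={fid:.2f} MSE={mse:.4f} PSNR={psnr:.2f} SSIM={ssim:.4f} LPIPS={lpips_score:.4f}")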