TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/utils.py ADDED
@@ -0,0 +1,1793 @@
1
+ """
2
+ **Utilities for text encoding, noise prediction, and evaluation in diffusion models**
3
+
4
+ This module provides core components for building diffusion model pipelines, including
5
+ text encoding (used as the conditional model), U-Net-based noise prediction, and image quality evaluation. These
6
+ utilities support various diffusion model architectures, such as DDPM, DDIM, LDM, and
7
+ SDE, and are designed for standalone use in model training and sampling.
8
+
9
+ **Primary Components**
10
+
11
+ - **TextEncoder**: Encodes text prompts into embeddings using a pre-trained BERT model or a custom transformer.
12
+ - **NoisePredictor**: U-Net-like architecture for predicting noise in diffusion models, supporting time and text conditioning.
13
+ - **Metrics**: Computes image quality metrics (MSE, PSNR, SSIM, FID, LPIPS) for evaluating generated images.
14
+
15
+ **Notes**
16
+
17
+ - The primary components are intended to be imported directly for use in diffusion model workflows.
18
+ - Additional supporting classes and functions in this module provide internal functionality for the primary components.
19
+
20
+
21
+ ---------------------------------------------------------------------------------
22
+ """
23
+
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.nn.attention import sdpa_kernel, SDPBackend
29
+ from pytorch_fid import fid_score
30
+ from torchvision.utils import save_image
31
+ from transformers import BertModel
32
+ import os
33
+ import lpips
34
+ import math
35
+ import shutil
36
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
37
+ from torchmetrics.image.fid import FrechetInceptionDistance
38
+ from torchvision.utils import save_image
39
+ from typing import Optional, Tuple, List
40
+
41
+
42
+ ###==================================================================================================================###
43
+
44
class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes text prompts into embeddings using either a pre-trained BERT model or a
    custom transformer architecture. Used as the `conditional_model` in diffusion models
    (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs for noise prediction.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, uses a pre-trained BERT model; otherwise, builds a custom transformer
        (default: True).
    model_name : str, optional
        Name of the pre-trained model to load (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Size of the vocabulary for the custom transformer's embedding layer
        (default: 30522).
    num_layers : int, optional
        Number of transformer encoder layers for the custom transformer (default: 6).
    input_dimension : int, optional
        Embedding dimension used inside the custom transformer (default: 768).
        Must be divisible by `num_heads`.
    output_dimension : int, optional
        Output embedding dimension for both pre-trained and custom models
        (default: 768).
    num_heads : int, optional
        Number of attention heads in the custom transformer (default: 8).
    context_length : int, optional
        Maximum sequence length for text prompts (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        If True, includes bias in the attention projections of the custom
        transformer (default: False).
    scaling_value : int, optional
        Scaling factor for the feedforward layer's hidden dimension in the custom
        transformer (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom transformer (default: 1e-5).
    use_learned_pos : bool, optional
        If True, the custom transformer uses learnable positional embeddings instead
        of sinusoidal encodings (default: False).

    **Notes**

    - When `use_pretrained_model` is True, the BERT model's parameters are frozen
      (`requires_grad = False`), and a projection layer maps outputs to
      `output_dimension`.
    - The custom transformer stacks `EncoderLayer` modules. All layers except the
      last keep `input_dimension`; only the last one maps to `output_dimension`,
      so the stack chains correctly even when the two dimensions differ. With
      `num_layers == 0` the custom path returns the raw embeddings
      (`input_dimension` wide).
    - The output shape is (batch_size, seq_len, output_dimension).
    """
    def __init__(
        self,
        use_pretrained_model: bool = True,
        model_name: str = "bert-base-uncased",
        vocabulary_size: int = 30522,
        num_layers: int = 6,
        input_dimension: int = 768,
        output_dimension: int = 768,
        num_heads: int = 8,
        context_length: int = 77,
        dropout_rate: float = 0.1,
        qkv_bias: bool = False,
        scaling_value: int = 4,
        epsilon: float = 1e-5,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # Freeze BERT: only the projection head below receives gradients.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                max_context_length=context_length,
                use_learned_pos=use_learned_pos
            )
            # Every layer consumes `input_dimension` features; only the final layer
            # emits `output_dimension`, so layer i's output always matches layer
            # i+1's expected input. Identical to the old stack when the dims agree.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension,
                    output_dimension=output_dimension if i == num_layers - 1 else input_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for i in range(num_layers)
            ])

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encodes text prompts into embeddings.

        Processes input token IDs and an optional attention mask to produce embeddings
        using either a pre-trained BERT model or a custom transformer.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 1 marks valid tokens
            and 0 marks padding tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Encoded embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - For pre-trained BERT, the `last_hidden_state` is projected to
          `output_dimension` and this projection is the only trainable layer.
        - For the custom transformer, token embeddings are processed through
          `Embedding` and `EncoderLayer` modules; the same 1 = valid / 0 = padding
          mask convention as BERT is used.
        """
        if self.use_pretrained_model:
            x = self.bert(input_ids=x, attention_mask=attention_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x
175
+ ###==================================================================================================================###
176
+
177
class EncoderLayer(torch.nn.Module):
    """Single transformer encoder layer with multi-head attention and feedforward network.

    Used in the custom transformer of `TextEncoder` to process embedded text prompts.

    Parameters
    ----------
    input_dimension : int
        Input embedding dimension. Attention and the feedforward network operate at
        this width, so it must be divisible by `num_heads`.
    output_dimension : int
        Output embedding dimension. When it differs from `input_dimension`, a linear
        projection maps the layer's result to this width.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and feedforward layers.
    qkv_bias : bool
        If True, includes bias in the attention's input/output projections
        (forwarded as `bias` to `nn.MultiheadAttention`).
    scaling_value : int
        Scaling factor for the feedforward layer's hidden dimension.
    epsilon : float, optional
        Epsilon for layer normalization (default: 1e-5).

    **Notes**

    - Post-norm layout: self-attention, residual connection, normalization,
      feedforward, residual connection, projection to `output_dimension`
      (identity when the dimensions match), normalization.
    - Both residual additions happen at `input_dimension`; projecting only at the
      end keeps the residual shapes consistent when the dimensions differ.
    - The attention mechanism uses `batch_first=True` for compatibility with
      `TextEncoder`'s input format.
    """
    def __init__(
        self,
        input_dimension: int,
        output_dimension: int,
        num_heads: int,
        dropout_rate: float,
        qkv_bias: bool,
        scaling_value: int,
        epsilon: float = 1e-5
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=input_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=input_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input embeddings through attention and feedforward layers.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 1/True marks valid
            tokens and 0/False marks padding tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - `nn.MultiheadAttention` expects `key_padding_mask` to be True at positions
          to *ignore*, so the 0 = padding mask is converted with `attention_mask == 0`.
        - Residual connections and normalization are applied after attention and
          feedforward layers.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # Convert "0 = padding" into PyTorch's "True = ignore" boolean mask.
            key_padding_mask = attention_mask == 0
        # Self-attention: query, key and value are all `x` (key/value are required args).
        attn_output, _ = self.attention(
            x, x, x, key_padding_mask=key_padding_mask, need_weights=False
        )
        x = self.norm1(x + self.dropout1(attn_output))
        x = x + self.dropout2(self.feedforward(x))
        # Project last so both residual sums above stay at `input_dimension`.
        return self.norm2(self.output_projection(x))

###==================================================================================================================###

class FeedForward(torch.nn.Module):
    """Feedforward network for transformer encoder layers.

    Used in `EncoderLayer` to process attention outputs with a two-layer MLP and GELU
    activation.

    Parameters
    ----------
    embedding_dimension : int
        Input and output embedding dimension.
    scaling_value : int
        Scaling factor for the hidden layer's dimension (hidden_dim =
        embedding_dimension * scaling_value).
    dropout_rate : float, optional
        Dropout rate after the hidden layer's activation (default: 0.1).

    **Notes**

    - Layout: Linear -> GELU -> Dropout -> Linear, applied independently to the
      last dimension of the input.
    - The hidden layer dimension is `embedding_dimension * scaling_value`.
    """
    def __init__(self, embedding_dimension: int, scaling_value: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=embedding_dimension,
                out_features=embedding_dimension * scaling_value,
                bias=True
            ),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(
                in_features=embedding_dimension * scaling_value,
                out_features=embedding_dimension,
                bias=True
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Processes input embeddings through the feedforward network.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, embedding_dimension).
        """
        return self.layers(x)
318
+
319
+ ###==================================================================================================================###
320
+
321
+
322
class Attention(nn.Module):
    """Multi-head attention over the spatial positions of a feature map.

    Each pixel of the input map is treated as a token. Without conditioning the
    tokens attend to themselves; when text embeddings are supplied they are
    projected to the feature width and used as keys/values (cross-attention).

    Parameters
    ----------
    in_channels : int
        Channel count of the feature map; also the attention embedding width.
    y_embed_dim : int, optional
        Width of the conditioning (text) embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Group count for the output group normalization (default: 8).
    dropout_rate : float, optional
        Dropout probability inside attention and on the output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Expected channel count of `x`.
    y_embed_dim : int
        Expected width of `y` before projection.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout probability.
    attention : torch.nn.MultiheadAttention
        Attention module built with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention result.
    dropout : torch.nn.Dropout
        Dropout applied after normalization.
    y_projection : torch.nn.Linear
        Maps `y` from `y_embed_dim` to `in_channels`.

    Raises
    ------
    AssertionError
        If the channel count of `x` differs from `in_channels`.
    ValueError
        If the projected conditioning tensor is neither 2D nor 3D.
    """
    def __init__(
        self,
        in_channels: int,
        y_embed_dim: int = 768,
        num_heads: int = 4,
        num_groups: int = 8,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        # Hyper-parameters, kept for introspection.
        self.in_channels, self.y_embed_dim = in_channels, y_embed_dim
        self.num_heads, self.dropout_rate = num_heads, dropout_rate
        # Sub-modules, created in a fixed order so seeded initialisation is stable.
        self.attention = nn.MultiheadAttention(
            embed_dim=in_channels,
            num_heads=num_heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def _prepare_context(self, y: torch.Tensor) -> torch.Tensor:
        """Project `y` to `in_channels` and make it (batch, tokens, in_channels)."""
        context = self.y_projection(y)
        rank = context.dim()
        if rank == 2:
            # A single embedding per sample becomes a one-token sequence.
            context = context.unsqueeze(1)
        elif rank != 3:
            raise ValueError(
                f"Expected y to be 2D or 3D after projection, got {rank}D with shape {context.shape}"
            )
        if context.shape[-1] != self.in_channels:
            raise ValueError(
                f"Expected y's embedding dim to match in_channels ({self.in_channels}), got {context.shape[-1]}"
            )
        return context

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        """Attend over the pixels of `x`, optionally conditioned on `y`.

        Parameters
        ----------
        x : torch.Tensor
            Feature map, shape (batch_size, in_channels, height, width).
        y : torch.Tensor, optional
            Conditioning embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None, i.e. self-attention).

        Returns
        -------
        torch.Tensor
            Normalized, dropout-applied attention result with the same shape as `x`.
        """
        n, c, height, width = x.shape
        assert c == self.in_channels, f"Expected {self.in_channels} channels, got {c}"
        # (B, C, H, W) -> (B, H*W, C): one token per pixel.
        tokens = x.view(n, c, height * width).permute(0, 2, 1)
        if y is None:
            attended = self.attention(tokens, tokens, tokens)[0]
        else:
            context = self._prepare_context(y)
            attended = self.attention(tokens, context, context)[0]
        # (B, H*W, C) -> (B, C, H, W)
        restored = attended.permute(0, 2, 1).view(n, c, height, width)
        return self.dropout(self.norm(restored))
424
+
425
+
426
+ ###==================================================================================================================###
427
+
428
+
429
class Embedding(nn.Module):
    """Token and positional embedding layer for transformer inputs.

    Used in `TextEncoder`'s transformer to embed token IDs and add positional encodings.

    Parameters
    ----------
    vocabulary_size : int
        Size of the vocabulary for token embeddings.
    embedding_dimension : int, optional
        Dimension of token and positional embeddings (default: 768).
    max_context_length : int, optional
        Maximum sequence length for learned positional embeddings, and the minimum
        length precomputed into the sinusoidal cache (default: 77).
    use_learned_pos : bool, optional
        If True, uses learnable positional embeddings instead of sinusoidal encodings
        (default: False).

    **Notes**

    - Supports both sinusoidal (fixed) and learned positional embeddings, selectable via
      `use_learned_pos`.
    - Sinusoidal encodings are generated lazily into the `positional_encoding_cache`
      buffer, covering at least `max_context_length` positions. The buffer is
      regenerated when a longer sequence or a different device is seen.
    - Learned positional embeddings are a parameter of shape
      (1, max_context_length, embedding_dimension), initialized from a scaled normal.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """
    def __init__(
        self,
        vocabulary_size: int,
        embedding_dimension: int = 768,
        max_context_length: int = 77,
        use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.embedding_dimension = embedding_dimension
        self.max_context_length = max_context_length
        self.use_learned_pos = use_learned_pos

        # Token embedding layer
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )

        if use_learned_pos:
            # Learnable positional embeddings
            self.positional_embedding = nn.Parameter(
                torch.randn(1, max_context_length, embedding_dimension) / math.sqrt(embedding_dimension)
            )
        else:
            # Empty buffer; filled on the first forward pass.
            self.register_buffer(
                "positional_encoding_cache",
                torch.empty(1, 0, embedding_dimension, dtype=torch.float32)
            )

    def _generate_positional_encoding(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generates sinusoidal positional encodings for transformer inputs.

        Computes positional encodings using sine and cosine functions.

        Parameters
        ----------
        seq_len : int
            Length of the sequence for which to generate positional encodings.
        device : torch.device
            Device on which to create the positional encodings.

        Returns
        -------
        torch.Tensor
            Positional encodings, shape (1, seq_len, embedding_dimension), float32, where
            even-indexed dimensions use sine and odd-indexed dimensions use cosine.

        **Notes**

        - Uses the formula: for position `pos` and dimension `i`,
          `PE(pos, 2i) = sin(pos / 10000^(2i/d))` and
          `PE(pos, 2i+1) = cos(pos / 10000^(2i/d))`, where `d` is `embedding_dimension`.
        - Works for odd `d`: there are `ceil(d/2)` sine columns and `floor(d/2)`
          cosine columns.
        """
        d = self.embedding_dimension
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)  # (seq_len, 1)
        div_term = torch.exp(
            torch.arange(0, d, 2, dtype=torch.float32, device=device) *
            (-math.log(10000.0) / d)
        )  # (ceil(d / 2),)
        angles = position * div_term  # (seq_len, ceil(d / 2))
        pos_enc = torch.zeros((1, seq_len, d), dtype=torch.float32, device=device)
        pos_enc[:, :, 0::2] = torch.sin(angles)
        # Odd d has one fewer cosine column than sine column.
        pos_enc[:, :, 1::2] = torch.cos(angles[:, : d // 2])
        return pos_enc

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embeds token IDs and adds positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embedded tokens with positional encodings, shape
            (batch_size, seq_len, embedding_dimension).

        Raises
        ------
        AssertionError
            If `token_ids` is not 2-D.
        ValueError
            If learned positional embeddings are used and `seq_len` exceeds
            `max_context_length`.

        **Notes**

        - With sinusoidal encodings, sequences longer than `max_context_length` are
          handled by regenerating the cache.
        - Encodings are (re)generated on the input's device when the cache lives
          elsewhere.
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        batch_size, seq_len = token_ids.size()
        device = token_ids.device

        # Compute token embeddings
        token_embedded = self.token_embedding(token_ids)

        # Handle positional embeddings
        if self.use_learned_pos:
            if seq_len > self.max_context_length:
                raise ValueError(
                    f"Sequence length ({seq_len}) exceeds max_context_length ({self.max_context_length}) "
                    "for learned positional embeddings."
                )
            position_encoded = self.positional_embedding[:, :seq_len, :]
        else:
            # Regenerate the cache if it is too short or on the wrong device.
            if (self.positional_encoding_cache.size(1) < seq_len or
                    self.positional_encoding_cache.device != device):
                self.positional_encoding_cache = self._generate_positional_encoding(
                    max(seq_len, self.max_context_length), device
                )
            position_encoded = self.positional_encoding_cache[:, :seq_len, :]

        return token_embedded + position_encoded
569
+
570
+
571
+ ###==================================================================================================================###
572
+
573
+
574
class NoisePredictor(nn.Module):
    """U-Net-like architecture for noise prediction in Diffusion Models.

    Predicts noise for diffusion models (DDPM, DDIM, SDE), incorporating
    time embeddings and optional text conditioning. Used as the `noise_predictor` in
    `Train` and `Sample` from the `ldm`, `sde`, `ddpm`, `ddim` modules.

    Parameters
    ----------
    in_channels : int
        Number of input channels (also the number of output channels).
    down_channels : list of int
        Channel sequence for the downsampling path; block i maps
        `down_channels[i] -> down_channels[i+1]`.
    mid_channels : list of int
        Channel sequence for the middle blocks.
    up_channels : list of int
        Channel sequence for the upsampling path.
    down_sampling : list of bool
        Whether each down block downsamples; reversed for the up blocks.
    time_embed_dim : int
        Dimensionality of time embeddings.
    y_embed_dim : int
        Dimensionality of text embeddings for conditioning.
    num_down_blocks : int
        Number of convolutional layer pairs per down block.
    num_mid_blocks : int
        Number of convolutional layer pairs per middle block.
    num_up_blocks : int
        Number of convolutional layer pairs per up block.
    dropout_rate : float, optional
        Dropout rate for convolutional and attention layers (default: 0.1).
    down_sampling_factor : int, optional
        Factor for spatial downsampling/upsampling (default: 2).
    where_y : bool, optional
        If True, text embeddings are used in attention; if False, concatenated to input
        along the channel axis (default: True).
    y_to_all : bool, optional
        If True, apply text-conditioned attention to all layers; if False, only first layer
        (default: False).

    **Notes**

    - The architecture follows a U-Net structure with downsampling, bottleneck, and
      upsampling blocks, incorporating time embeddings and optional text conditioning via
      attention or concatenation.
    - Skip connections link down and up blocks; `skip_channels` for up block i is
      `reversed(down_channels)[i]`.
    - The constructor keeps PyTorch's default initialization. Call
      `initialize_weights()` explicitly to apply Kaiming normal (Leaky ReLU, a=0.2)
      initialization.
    - The final `GroupNorm` uses 8 groups, so `up_channels[-1]` must be divisible by 8.
    - `conv2` maps back to `in_channels`, so the output has the same shape as the
      (unconcatenated) input when `where_y` is True.
    """
    def __init__(
        self,
        in_channels: int,
        down_channels: List[int],
        mid_channels: List[int],
        up_channels: List[int],
        down_sampling: List[bool],
        time_embed_dim: int,
        y_embed_dim: int,
        num_down_blocks: int,
        num_mid_blocks: int,
        num_up_blocks: int,
        dropout_rate: float = 0.1,
        down_sampling_factor: int = 2,
        where_y: bool = True,
        y_to_all: bool = False
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.down_channels = down_channels
        self.mid_channels = mid_channels
        self.up_channels = up_channels
        self.down_sampling = down_sampling
        self.time_embed_dim = time_embed_dim
        self.y_embed_dim = y_embed_dim
        self.num_down_blocks = num_down_blocks
        self.num_mid_blocks = num_mid_blocks
        self.num_up_blocks = num_up_blocks
        self.dropout_rate = dropout_rate
        self.down_sampling_factor = down_sampling_factor
        self.where_y = where_y
        self.y_to_all = y_to_all
        # Up path mirrors the down path.
        self.up_sampling = list(reversed(self.down_sampling))
        self.conv1 = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.down_channels[0],
            kernel_size=3,
            padding=1
        )
        # initial time embedding projection
        self.time_projection = nn.Sequential(
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(in_features=self.time_embed_dim, out_features=self.time_embed_dim)
        )
        # down blocks
        self.down_blocks = nn.ModuleList([
            DownBlock(
                in_channels=self.down_channels[i],
                out_channels=self.down_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_down_blocks,
                down_sampling_factor=down_sampling_factor,
                down_sample=self.down_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.down_channels) - 1)
        ])
        # middle blocks
        self.mid_blocks = nn.ModuleList([
            MiddleBlock(
                in_channels=self.mid_channels[i],
                out_channels=self.mid_channels[i + 1],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_mid_blocks,
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.mid_channels) - 1)
        ])
        # up blocks: skip tensors are consumed in reverse order of production.
        skip_channels = list(reversed(self.down_channels))
        self.up_blocks = nn.ModuleList([
            UpBlock(
                in_channels=self.up_channels[i],
                out_channels=self.up_channels[i + 1],
                skip_channels=skip_channels[i],
                time_embed_dim=self.time_embed_dim,
                y_embed_dim=y_embed_dim,
                num_layers=self.num_up_blocks,
                up_sampling_factor=down_sampling_factor,
                up_sampling=self.up_sampling[i],
                dropout_rate=self.dropout_rate,
                y_to_all=y_to_all
            ) for i in range(len(self.up_channels) - 1)
        ])
        # final convolution layer
        self.conv2 = nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=self.up_channels[-1]),
            nn.Dropout(p=self.dropout_rate),
            nn.Conv2d(in_channels=self.up_channels[-1], out_channels=self.in_channels, kernel_size=3, padding=1)
        )

    def initialize_weights(self) -> None:
        """Initializes model weights for training stability.

        Applies Kaiming normal initialization to every `Conv2d`, `Linear` and
        `ConvTranspose2d` weight with Leaky ReLU nonlinearity (a=0.2), and zeros
        their biases. Not called by the constructor; invoke it explicitly.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity='leaky_relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        clip_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Predicts noise given input, time step, and optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width).
        t : torch.Tensor
            Time steps, shape (batch_size,).
        y : torch.Tensor, optional
            Text embeddings for conditioning, shape (batch_size, seq_len, y_embed_dim)
            or (batch_size, y_embed_dim) (default: None). When `where_y` is False it is
            concatenated to `x` along dim 1 before `conv1`.
        clip_embeddings : torch.Tensor, optional
            unCLIP conditioning added element-wise to the projected time embedding;
            must be broadcast-compatible with it (default: None).

        Returns
        -------
        output (torch.Tensor) - Predicted noise with `in_channels` channels.
        """
        if not self.where_y and y is not None:
            x = torch.cat(tensors=[x, y], dim=1)
        output = self.conv1(x)
        # NOTE(review): a new GetEmbeddedTime is built on every call; this is only
        # correct if it holds no trainable state — confirm against its definition.
        time_embed = GetEmbeddedTime(embed_dim=self.time_embed_dim)(time_steps=t)
        time_embed = self.time_projection(time_embed)
        if clip_embeddings is not None:
            time_embed = time_embed + clip_embeddings

        # Encoder: remember each block's input for the matching decoder block.
        skip_connections = []
        for down in self.down_blocks:
            skip_connections.append(output)
            output = down(x=output, embed_time=time_embed, y=y)
        for mid in self.mid_blocks:
            output = mid(x=output, embed_time=time_embed, y=y)
        # Decoder: consume skips last-in, first-out.
        for up in self.up_blocks:
            output = up(x=output, skip_connection=skip_connections.pop(), embed_time=time_embed, y=y)

        return self.conv2(output)
776
+
777
+ ###==================================================================================================================###
778
+
779
+ class DownBlock(nn.Module):
780
+ """Downsampling block for NoisePredictor’s encoder.
781
+
782
+ Applies convolutional layers with residual connections, time embeddings, and optional
783
+ text-conditioned attention, followed by downsampling if enabled.
784
+
785
+ Parameters
786
+ ----------
787
+ in_channels : int
788
+ Number of input channels.
789
+ out_channels : int
790
+ Number of output channels.
791
+ time_embed_dim : int
792
+ Dimensionality of time embeddings.
793
+ y_embed_dim : int
794
+ Dimensionality of text embeddings.
795
+ num_layers : int
796
+ Number of convolutional layer pairs (Conv3).
797
+ down_sampling_factor : int
798
+ Factor for spatial downsampling.
799
+ down_sample : bool
800
+ If True, apply downsampling; if False, use identity (no downsampling).
801
+ dropout_rate : float
802
+ Dropout rate for Conv3 and attention layers.
803
+ y_to_all : bool
804
+ If True, apply text-conditioned attention to all layers; if False, only first layer.
805
+ """
806
+ def __init__(
807
+ self,
808
+ in_channels: int,
809
+ out_channels: int ,
810
+ time_embed_dim: int,
811
+ y_embed_dim: int,
812
+ num_layers: int,
813
+ down_sampling_factor: int,
814
+ down_sample: bool,
815
+ dropout_rate: float,
816
+ y_to_all: bool
817
+ ) -> None:
818
+ super().__init__()
819
+ self.num_layers = num_layers
820
+ self.y_to_all = y_to_all
821
+ self.conv1 = nn.ModuleList([
822
+ Conv3(
823
+ in_channels=in_channels if i==0 else out_channels,
824
+ out_channels=out_channels,
825
+ num_groups=8,
826
+ kernel_size=3,
827
+ norm=True,
828
+ activation=True,
829
+ dropout_rate=dropout_rate
830
+ ) for i in range(self.num_layers)
831
+ ])
832
+ self.conv2 = nn.ModuleList([
833
+ Conv3(
834
+ in_channels=out_channels,
835
+ out_channels=out_channels,
836
+ num_groups=8,
837
+ kernel_size=3,
838
+ norm=True,
839
+ activation=True,
840
+ dropout_rate=dropout_rate
841
+ ) for _ in range(self.num_layers)
842
+ ])
843
+ self.time_embedding = nn.ModuleList([
844
+ TimeEmbedding(
845
+ output_dim=out_channels,
846
+ embed_dim=time_embed_dim
847
+ ) for _ in range(self.num_layers)
848
+ ])
849
+ self.attention = nn.ModuleList([
850
+ Attention(
851
+ in_channels=out_channels,
852
+ y_embed_dim=y_embed_dim,
853
+ num_groups=8,
854
+ num_heads=4,
855
+ dropout_rate=dropout_rate
856
+ ) for _ in range(self.num_layers)
857
+ ])
858
+ self.down_sampling = DownSampling(
859
+ in_channels=out_channels,
860
+ out_channels=out_channels,
861
+ down_sampling_factor=down_sampling_factor,
862
+ conv_block=True,
863
+ max_pool=True
864
+ ) if down_sample else nn.Identity()
865
+ self.resnet = nn.ModuleList([
866
+ nn.Conv2d(
867
+ in_channels=in_channels if i == 0 else out_channels,
868
+ out_channels=out_channels,
869
+ kernel_size=1
870
+ ) for i in range(num_layers)
871
+
872
+ ])
873
+
874
+ def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
875
+ """Processes input through convolutions, time embeddings, attention, and downsampling.
876
+
877
+ Parameters
878
+ ----------
879
+ x : torch.Tensor
880
+ Input tensor, shape (batch_size, in_channels, height, width).
881
+ embed_time : torch.Tensor
882
+ Time embeddings, shape (batch_size, time_embed_dim).
883
+ y : torch.Tensor, optional
884
+ Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
885
+ (batch_size, y_embed_dim) (default: None).
886
+ key_padding_mask : torch.Tensor, optional
887
+ Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
888
+ where `True` indicates positions to mask out (default: None).
889
+
890
+ Returns
891
+ -------
892
+ output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor) if downsampling; otherwise, same height/width as input.
893
+ """
894
+ output = x
895
+ for i in range(self.num_layers):
896
+ resnet_input = output
897
+ output = self.conv1[i](output)
898
+ output = output + self.time_embedding[i](embed_time)[:, :, None, None]
899
+ output = self.conv2[i](output)
900
+ output = output + self.resnet[i](resnet_input)
901
+
902
+ if not self.y_to_all and i == 0:
903
+ out_attn = self.attention[i](output, y)
904
+ output = output + out_attn
905
+ elif self.y_to_all:
906
+ out_attn = self.attention[i](output, y)
907
+ output = output + out_attn
908
+
909
+ output = self.down_sampling(output)
910
+ return output
911
+
912
+ ###==================================================================================================================###
913
+
914
+ class MiddleBlock(nn.Module):
915
+ """Bottleneck block for NoisePredictor’s middle layers.
916
+
917
+ Applies convolutional layers with residual connections, time embeddings, and optional
918
+ text-conditioned attention, preserving spatial dimensions.
919
+
920
+ Parameters
921
+ ----------
922
+ in_channels : int
923
+ Number of input channels.
924
+ out_channels : int
925
+ Number of output channels.
926
+ time_embed_dim : int
927
+ Dimensionality of time embeddings.
928
+ y_embed_dim : int
929
+ Dimensionality of text embeddings.
930
+ num_layers : int
931
+ Number of convolutional layer pairs (Conv3).
932
+ dropout_rate : float
933
+ Dropout rate for Conv3 and attention layers.
934
+ y_to_all : bool
935
+ If True, apply text-conditioned attention to all layers; if False, only first layer
936
+ (default: False).
937
+ """
938
+ def __init__(
939
+ self,
940
+ in_channels: int,
941
+ out_channels: int,
942
+ time_embed_dim: int,
943
+ y_embed_dim: int,
944
+ num_layers: int,
945
+ dropout_rate: float,
946
+ y_to_all: bool
947
+ ) -> None:
948
+ super().__init__()
949
+ self.num_layers = num_layers
950
+ self.y_to_all = y_to_all
951
+ self.conv1 = nn.ModuleList([
952
+ Conv3(
953
+ in_channels=in_channels if i == 0 else out_channels,
954
+ out_channels=out_channels,
955
+ num_groups=8,
956
+ kernel_size=3,
957
+ norm=True,
958
+ activation=True,
959
+ dropout_rate=dropout_rate
960
+ ) for i in range(self.num_layers+1)
961
+ ])
962
+ self.conv2 = nn.ModuleList([
963
+ Conv3(
964
+ in_channels=out_channels,
965
+ out_channels=out_channels,
966
+ num_groups=8,
967
+ kernel_size=3,
968
+ norm=True,
969
+ activation=True,
970
+ dropout_rate=dropout_rate
971
+ ) for _ in range(self.num_layers+1)
972
+ ])
973
+ self.time_embedding = nn.ModuleList([
974
+ TimeEmbedding(
975
+ output_dim=out_channels,
976
+ embed_dim=time_embed_dim
977
+ ) for _ in range(self.num_layers+1)
978
+ ])
979
+ self.attention = nn.ModuleList([
980
+ Attention(
981
+ in_channels=out_channels,
982
+ y_embed_dim=y_embed_dim,
983
+ num_groups=8,
984
+ num_heads=4,
985
+ dropout_rate=dropout_rate
986
+ ) for _ in range(self.num_layers + 1)
987
+ ])
988
+ self.resnet = nn.ModuleList([
989
+ nn.Conv2d(
990
+ in_channels=in_channels if i == 0 else out_channels,
991
+ out_channels=out_channels,
992
+ kernel_size=1
993
+ ) for i in range(num_layers+1)
994
+ ])
995
+
996
+ def forward(self, x: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
997
+ """Processes input through convolutions, time embeddings, and attention.
998
+
999
+ Parameters
1000
+ ----------
1001
+ x : torch.Tensor
1002
+ Input tensor, shape (batch_size, in_channels, height, width).
1003
+ embed_time : torch.Tensor
1004
+ Time embeddings, shape (batch_size, time_embed_dim).
1005
+ y : torch.Tensor, optional
1006
+ Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
1007
+ (batch_size, y_embed_dim) (default: None).
1008
+ key_padding_mask : torch.Tensor, optional
1009
+ Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
1010
+ where `True` indicates positions to mask out (default: None).
1011
+
1012
+ Returns
1013
+ -------
1014
+ output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
1015
+ """
1016
+ output = x
1017
+ resnet_input = output
1018
+ output = self.conv1[0](output)
1019
+ output = output + self.time_embedding[0](embed_time)[:, :, None, None]
1020
+ output = self.conv2[0](output)
1021
+ output = output + self.resnet[0](resnet_input)
1022
+
1023
+ for i in range(self.num_layers):
1024
+ if not self.y_to_all and i == 0:
1025
+ out_attn = self.attention[i](output, y)
1026
+ output = output + out_attn
1027
+ elif self.y_to_all:
1028
+ out_attn = self.attention[i](output, y)
1029
+ output = output + out_attn
1030
+ resnet_input = output
1031
+ output = self.conv1[i + 1](output)
1032
+ output = output + self.time_embedding[i + 1](embed_time)[:, :, None, None]
1033
+ output = self.conv2[i + 1](output)
1034
+ output = output + self.resnet[i+1](resnet_input)
1035
+ return output
1036
+
1037
+ ###==================================================================================================================###
1038
+
1039
+ class UpBlock(nn.Module):
1040
+ """Upsampling block for NoisePredictor’s decoder.
1041
+
1042
+ Applies upsampling (if enabled), concatenates skip connections, and processes through
1043
+ convolutional layers with residual connections, time embeddings, and optional
1044
+ text-conditioned attention.
1045
+
1046
+ Parameters
1047
+ ----------
1048
+ in_channels : int
1049
+ Number of input channels (before upsampling).
1050
+ out_channels : int
1051
+ Number of output channels.
1052
+ skip_channels : int
1053
+ Number of channels from skip connection.
1054
+ time_embed_dim : int
1055
+ Dimensionality of time embeddings.
1056
+ y_embed_dim : int
1057
+ Dimensionality of text embeddings.
1058
+ num_layers : int
1059
+ Number of convolutional layer pairs (Conv3).
1060
+ up_sampling_factor : int
1061
+ Factor for spatial upsampling.
1062
+ up_sampling : bool
1063
+ If True, apply upsampling; if False, use identity (no upsampling).
1064
+ dropout_rate : float
1065
+ Dropout rate for Conv3 and attention layers.
1066
+ y_to_all : bool
1067
+ If True, apply text-conditioned attention to all layers; if False, only first layer
1068
+ (default: False).
1069
+ """
1070
+ def __init__(
1071
+ self,
1072
+ in_channels: int,
1073
+ out_channels: int,
1074
+ skip_channels: int,
1075
+ time_embed_dim: int,
1076
+ y_embed_dim: int,
1077
+ num_layers: int,
1078
+ up_sampling_factor: int,
1079
+ up_sampling: bool,
1080
+ dropout_rate: float,
1081
+ y_to_all: bool
1082
+ ) -> None:
1083
+ super().__init__()
1084
+ self.num_layers = num_layers
1085
+ self.y_to_all = y_to_all
1086
+ effective_in_channels = in_channels // 2 + skip_channels
1087
+ self.conv1 = nn.ModuleList([
1088
+ Conv3(
1089
+ in_channels=effective_in_channels if i == 0 else out_channels,
1090
+ out_channels=out_channels,
1091
+ num_groups=8,
1092
+ kernel_size=3,
1093
+ norm=True,
1094
+ activation=True,
1095
+ dropout_rate=dropout_rate
1096
+ ) for i in range(self.num_layers)
1097
+ ])
1098
+ self.conv2 = nn.ModuleList([
1099
+ Conv3(
1100
+ in_channels=out_channels,
1101
+ out_channels=out_channels,
1102
+ num_groups=8,
1103
+ kernel_size=3,
1104
+ norm=True,
1105
+ activation=True,
1106
+ dropout_rate=dropout_rate
1107
+ ) for _ in range(self.num_layers)
1108
+ ])
1109
+ self.time_embedding = nn.ModuleList([
1110
+ TimeEmbedding(
1111
+ output_dim=out_channels,
1112
+ embed_dim=time_embed_dim
1113
+ ) for _ in range(self.num_layers)
1114
+ ])
1115
+ self.attention = nn.ModuleList([
1116
+ Attention(
1117
+ in_channels=out_channels,
1118
+ y_embed_dim=y_embed_dim,
1119
+ num_groups=8,
1120
+ num_heads=4,
1121
+ dropout_rate=dropout_rate
1122
+ ) for _ in range(self.num_layers)
1123
+ ])
1124
+ self.up_sampling_ = UpSampling(
1125
+ in_channels=in_channels,
1126
+ out_channels=in_channels,
1127
+ up_sampling_factor=up_sampling_factor,
1128
+ conv_block=True,
1129
+ up_sampling=True
1130
+ ) if up_sampling else nn.Identity()
1131
+ self.resnet = nn.ModuleList([
1132
+ nn.Conv2d(
1133
+ in_channels=effective_in_channels if i == 0 else out_channels,
1134
+ out_channels=out_channels,
1135
+ kernel_size=1
1136
+ ) for i in range(num_layers)
1137
+
1138
+ ])
1139
+
1140
+ def forward(self, x: torch.Tensor, skip_connection: torch.Tensor, embed_time: torch.Tensor, y: Optional[torch.Tensor] = None, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1141
+ """Processes input through upsampling, skip connection, convolutions, time embeddings, and attention.
1142
+
1143
+ Parameters
1144
+ ----------
1145
+ x : torch.Tensor
1146
+ Input tensor, shape (batch_size, in_channels, height, width).
1147
+ skip_connection : torch.Tensor
1148
+ Skip connection tensor, shape (batch_size, skip_channels,
1149
+ height*up_sampling_factor, width*up_sampling_factor).
1150
+ embed_time : torch.Tensor
1151
+ Time embeddings, shape (batch_size, time_embed_dim).
1152
+ y : torch.Tensor, optional
1153
+ Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
1154
+ (batch_size, y_embed_dim) (default: None).
1155
+ key_padding_mask : torch.Tensor, optional
1156
+ Boolean mask, shape (batch_size, seq_len) if `y` is None, or (batch_size, seq_len_y) if `y` is provided,
1157
+ where `True` indicates positions to mask out (default: None).
1158
+
1159
+ Returns
1160
+ -------
1161
+ output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor) if upsampling; otherwise, same height/width as input (after skip connection).
1162
+ """
1163
+ x = self.up_sampling_(x)
1164
+ x = torch.cat(tensors=[x, skip_connection], dim=1)
1165
+ output = x
1166
+ for i in range(self.num_layers):
1167
+ resnet_input = output
1168
+ output = self.conv1[i](output)
1169
+ output = output + self.time_embedding[i](embed_time)[:, :, None, None]
1170
+ output = self.conv2[i](output)
1171
+ output = output + self.resnet[i](resnet_input)
1172
+
1173
+ if not self.y_to_all and i == 0:
1174
+ out_attn = self.attention[i](output, y)
1175
+ output = output + out_attn
1176
+ elif self.y_to_all:
1177
+ out_attn = self.attention[i](output, y)
1178
+ output = output + out_attn
1179
+
1180
+ return output
1181
+
1182
+ ###==================================================================================================================###
1183
+
1184
+ class Conv3(nn.Module):
1185
+ """Convolutional layer with optional group normalization, SiLU activation, and dropout.
1186
+
1187
+ Used in DownBlock, MiddleBlock, and UpBlock for feature extraction in NoisePredictor.
1188
+
1189
+ Parameters
1190
+ ----------
1191
+ in_channels : int
1192
+ Number of input channels.
1193
+ out_channels : int
1194
+ Number of output channels.
1195
+ num_groups : int, optional
1196
+ Number of groups for group normalization (default: 8).
1197
+ kernel_size : int, optional
1198
+ Convolutional kernel size (default: 3).
1199
+ norm : bool, optional
1200
+ If True, apply group normalization (default: True).
1201
+ activation : bool, optional
1202
+ If True, apply SiLU activation (default: True).
1203
+ dropout_rate : float, optional
1204
+ Dropout rate (default: 0.2).
1205
+ """
1206
+ def __init__(
1207
+ self,
1208
+ in_channels: int,
1209
+ out_channels: int,
1210
+ num_groups: int = 8,
1211
+ kernel_size: int = 3,
1212
+ norm: bool = True,
1213
+ activation: bool = True,
1214
+ dropout_rate: float = 0.2
1215
+ ) -> None:
1216
+ super().__init__()
1217
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2)
1218
+ self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_channels) if norm else nn.Identity()
1219
+ self.activation = nn.SiLU() if activation else nn.Identity()
1220
+ self.dropout = nn.Dropout(p=dropout_rate)
1221
+
1222
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
1223
+ """Processes input through convolution, normalization, activation, and dropout.
1224
+
1225
+ Parameters
1226
+ ----------
1227
+ batch : torch.Tensor
1228
+ Input tensor, shape (batch_size, in_channels, height, width).
1229
+
1230
+ Returns
1231
+ -------
1232
+ batch (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
1233
+ """
1234
+ batch = self.conv(batch)
1235
+ batch = self.group_norm(batch)
1236
+ batch = self.activation(batch)
1237
+ batch = self.dropout(batch)
1238
+ return batch
1239
+
1240
+ ###==================================================================================================================###
1241
+
1242
+ class TimeEmbedding(nn.Module):
1243
+ """Time embedding projection for conditioning NoisePredictor layers.
1244
+
1245
+ Projects time embeddings to match the channel dimension of convolutional outputs.
1246
+
1247
+ Parameters
1248
+ ----------
1249
+ output_dim : int
1250
+ Output channel dimension (matches convolutional channels).
1251
+ embed_dim : int
1252
+ Input time embedding dimension.
1253
+ """
1254
+ def __init__(self, output_dim: int, embed_dim: int) -> None:
1255
+ super().__init__()
1256
+ self.embedding = nn.Sequential(
1257
+ nn.SiLU(),
1258
+ nn.Linear(in_features=embed_dim, out_features=output_dim)
1259
+ )
1260
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
1261
+ """Projects time embeddings to output dimension.
1262
+
1263
+ Parameters
1264
+ ----------
1265
+ batch : torch.Tensor
1266
+ Time embeddings, shape (batch_size, embed_dim).
1267
+
1268
+ Returns
1269
+ -------
1270
+ torch.Tensor
1271
+ Projected embeddings, shape (batch_size, output_dim).
1272
+ """
1273
+ return self.embedding(batch)
1274
+
1275
+ ###==================================================================================================================###
1276
+
1277
+ class GetEmbeddedTime(nn.Module):
1278
+ """Generates sinusoidal time embeddings for NoisePredictor.
1279
+
1280
+ Creates positional encodings for time steps using sine and cosine functions, following
1281
+ the transformer embedding approach.
1282
+
1283
+ Parameters
1284
+ ----------
1285
+ embed_dim : int
1286
+ Dimensionality of the time embeddings (must be even).
1287
+ """
1288
+ def __init__(self, embed_dim: int) -> None:
1289
+ super().__init__()
1290
+ assert embed_dim % 2 == 0, "The embedding dimension must be divisible by two"
1291
+ self.embed_dim = embed_dim
1292
+
1293
+ def forward(self, time_steps: torch.Tensor) -> torch.Tensor:
1294
+ """Generates sinusoidal embeddings for time steps.
1295
+
1296
+ Parameters
1297
+ ----------
1298
+ time_steps : torch.Tensor
1299
+ Time steps, shape (batch_size,).
1300
+
1301
+ Returns
1302
+ -------
1303
+ embed_time (torch.Tensor) - Sinusoidal embeddings, shape (batch_size, embed_dim).
1304
+ """
1305
+ i = torch.arange(start=0, end=self.embed_dim // 2, dtype=torch.float32, device=time_steps.device)
1306
+ factor = 10000 ** (2 * i / self.embed_dim)
1307
+ embed_time = time_steps[:, None] / factor
1308
+ embed_time = torch.cat(tensors=[torch.sin(embed_time), torch.cos(embed_time)], dim=-1)
1309
+ return embed_time
1310
+
1311
+ ###==================================================================================================================###
1312
+
1313
+
1314
+ class DownSampling(nn.Module):
1315
+ """Downsampling module for NoisePredictor’s DownBlock.
1316
+
1317
+ Combines convolutional downsampling and max pooling (if enabled), concatenating
1318
+ outputs to preserve feature information.
1319
+
1320
+ Parameters
1321
+ ----------
1322
+ in_channels : int
1323
+ Number of input channels.
1324
+ out_channels : int
1325
+ Number of output channels.
1326
+ down_sampling_factor : int
1327
+ Factor for spatial downsampling.
1328
+ conv_block : bool, optional
1329
+ If True, include convolutional path (default: True).
1330
+ max_pool : bool, optional
1331
+ If True, include max pooling path (default: True).
1332
+ """
1333
+ def __init__(
1334
+ self,
1335
+ in_channels: int,
1336
+ out_channels: int,
1337
+ down_sampling_factor: int,
1338
+ conv_block: bool = True,
1339
+ max_pool: bool = True
1340
+ ) -> None:
1341
+ super().__init__()
1342
+ self.conv_block = conv_block
1343
+ self.max_pool = max_pool
1344
+ self.down_sampling_factor = down_sampling_factor
1345
+ self.conv = nn.Sequential(
1346
+ nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1),
1347
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels // 2 if max_pool else out_channels,
1348
+ kernel_size=3, stride=down_sampling_factor, padding=1)
1349
+ ) if conv_block else nn.Identity()
1350
+ self.pool = nn.Sequential(
1351
+ nn.MaxPool2d(kernel_size=down_sampling_factor, stride=down_sampling_factor),
1352
+ nn.Conv2d(in_channels=in_channels, out_channels=out_channels//2 if conv_block else out_channels,
1353
+ kernel_size=1, stride=1, padding=0)
1354
+ ) if max_pool else nn.Identity()
1355
+
1356
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
1357
+ """Downsamples input using convolutional and/or pooling paths.
1358
+
1359
+ Parameters
1360
+ ----------
1361
+ batch : torch.Tensor
1362
+ Input tensor, shape (batch_size, in_channels, height, width).
1363
+
1364
+ Returns
1365
+ -------
1366
+ batch (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
1367
+ """
1368
+ if not self.conv_block:
1369
+ return self.pool(batch)
1370
+ if not self.max_pool:
1371
+ return self.conv(batch)
1372
+ return torch.cat(tensors=[self.conv(batch), self.pool(batch)], dim=1)
1373
+
1374
+ ###==================================================================================================================###
1375
+
1376
+ class UpSampling(nn.Module):
1377
+ """Upsampling module for NoisePredictor’s UpBlock.
1378
+
1379
+ Combines transposed convolution and nearest-neighbor upsampling (if enabled),
1380
+ concatenating outputs to preserve feature information, with interpolation to align
1381
+ spatial dimensions if needed.
1382
+
1383
+ Parameters
1384
+ ----------
1385
+ in_channels : int
1386
+ Number of input channels.
1387
+ out_channels : int
1388
+ Number of output channels.
1389
+ up_sampling_factor : int
1390
+ Factor for spatial upsampling.
1391
+ conv_block : bool, optional
1392
+ If True, include transposed convolutional path (default: True).
1393
+ up_sampling : bool, optional
1394
+ If True, include nearest-neighbor upsampling path (default: True).
1395
+ """
1396
+ def __init__(
1397
+ self,
1398
+ in_channels: int,
1399
+ out_channels: int,
1400
+ up_sampling_factor: int,
1401
+ conv_block: bool = True,
1402
+ up_sampling: bool = True
1403
+ ) -> None:
1404
+ super().__init__()
1405
+ self.conv_block = conv_block
1406
+ self.up_sampling = up_sampling
1407
+ self.up_sampling_factor = up_sampling_factor
1408
+ half_out_channels = out_channels // 2
1409
+ self.conv = nn.Sequential(
1410
+ nn.ConvTranspose2d(
1411
+ in_channels=in_channels,
1412
+ out_channels=half_out_channels if up_sampling else out_channels,
1413
+ kernel_size=3,
1414
+ stride=up_sampling_factor,
1415
+ padding=1,
1416
+ output_padding=up_sampling_factor - 1
1417
+ ),
1418
+ nn.Conv2d(
1419
+ in_channels=half_out_channels if up_sampling else out_channels,
1420
+ out_channels=half_out_channels if up_sampling else out_channels,
1421
+ kernel_size=1,
1422
+ stride=1,
1423
+ padding=0
1424
+ )
1425
+ ) if conv_block else nn.Identity()
1426
+
1427
+ self.up_sample = nn.Sequential(
1428
+ nn.Upsample(scale_factor=up_sampling_factor, mode="nearest"),
1429
+ nn.Conv2d(in_channels=in_channels, out_channels=half_out_channels if conv_block else out_channels,
1430
+ kernel_size=1, stride=1, padding=0)
1431
+ ) if up_sampling else nn.Identity()
1432
+
1433
+ def forward(self, batch: torch.Tensor) -> torch.Tensor:
1434
+ """Upsamples input using convolutional and/or upsampling paths.
1435
+
1436
+ Parameters
1437
+ ----------
1438
+ batch : torch.Tensor
1439
+ Input tensor, shape (batch_size, in_channels, height, width).
1440
+
1441
+ Returns
1442
+ -------
1443
+ batch (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels, height*up_sampling_factor, width*up_sampling_factor).
1444
+
1445
+ **Notes**
1446
+
1447
+ - Interpolation is applied if the spatial dimensions of the convolutional and
1448
+ upsampling paths differ, using nearest-neighbor mode.
1449
+ """
1450
+ if not self.conv_block:
1451
+ return self.up_sample(batch)
1452
+ if not self.up_sampling:
1453
+ return self.conv(batch)
1454
+ conv_output = self.conv(batch)
1455
+ up_sample_output = self.up_sample(batch)
1456
+ if conv_output.shape[2:] != up_sample_output.shape[2:]:
1457
+ _, _, h, w = conv_output.shape
1458
+ up_sample_output = torch.nn.functional.interpolate(
1459
+ up_sample_output,
1460
+ size=(h, w),
1461
+ mode='nearest'
1462
+ )
1463
+ return torch.cat(tensors=[conv_output, up_sample_output], dim=1)
1464
+
1465
+ ###==================================================================================================================###
1466
+
1467
+ class Metrics:
1468
+ """Computes image quality metrics for evaluating diffusion models.
1469
+
1470
+ Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
1471
+ Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
1472
+ Image Patch Similarity (LPIPS) for comparing generated and ground truth images.
1473
+
1474
+ Parameters
1475
+ ----------
1476
+ device : str, optional
1477
+ Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
1478
+ fid : bool, optional
1479
+ If True, compute FID score (default: True).
1480
+ metrics : bool, optional
1481
+ If True, compute MSE, PSNR, and SSIM (default: False).
1482
+ lpips : bool, optional
1483
+ If True, compute LPIPS using VGG backbone (default: False).
1484
+ """
1485
+
1486
+ def __init__(
1487
+ self,
1488
+ device: str = "cuda",
1489
+ fid: bool = True,
1490
+ metrics: bool = False,
1491
+ lpips_: bool = False
1492
+ ) -> None:
1493
+ self.device = device
1494
+ self.fid = fid
1495
+ self.metrics = metrics
1496
+ self.lpips = lpips_
1497
+ self.lpips_model = LearnedPerceptualImagePatchSimilarity(
1498
+ net_type='vgg',
1499
+ normalize=True # This handles [0,1] -> [-1,1] conversion
1500
+ ).to(device) if self.lpips else None
1501
+ self.temp_dir_real = "temp_real"
1502
+ self.temp_dir_fake = "temp_fake"
1503
+
1504
+ def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
1505
+ """Computes the Fréchet Inception Distance (FID) between real and generated images.
1506
+
1507
+ Saves images to temporary directories and uses Inception V3 to compute FID,
1508
+ cleaning up directories afterward.
1509
+
1510
+ Parameters
1511
+ ----------
1512
+ real_images : torch.Tensor
1513
+ Real images, shape (batch_size, channels, height, width), in [-1, 1].
1514
+ fake_images : torch.Tensor
1515
+ Generated images, same shape, in [-1, 1].
1516
+
1517
+ Returns
1518
+ -------
1519
+ fid (float) - FID score, or `float('inf')` if computation fails.
1520
+
1521
+ **Notes**
1522
+
1523
+ - Images are normalized to [0, 1] and saved as PNG files for FID computation.
1524
+ - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
1525
+ """
1526
+ if real_images.shape != fake_images.shape:
1527
+ raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")
1528
+
1529
+ real_images = (real_images + 1) / 2
1530
+ fake_images = (fake_images + 1) / 2
1531
+ real_images = real_images.clamp(0, 1).cpu()
1532
+ fake_images = fake_images.clamp(0, 1).cpu()
1533
+
1534
+ os.makedirs(self.temp_dir_real, exist_ok=True)
1535
+ os.makedirs(self.temp_dir_fake, exist_ok=True)
1536
+
1537
+ try:
1538
+ for i, (real, fake) in enumerate(zip(real_images, fake_images)):
1539
+ save_image(real, f"{self.temp_dir_real}/{i}.png")
1540
+ save_image(fake, f"{self.temp_dir_fake}/{i}.png")
1541
+
1542
+ fid = fid_score.calculate_fid_given_paths(
1543
+ paths=[self.temp_dir_real, self.temp_dir_fake],
1544
+ batch_size=50,
1545
+ device=self.device,
1546
+ dims=2048
1547
+ )
1548
+ except Exception as e:
1549
+ print(f"Error computing FID: {e}")
1550
+ fid = float('inf')
1551
+ finally:
1552
+ shutil.rmtree(self.temp_dir_real, ignore_errors=True)
1553
+ shutil.rmtree(self.temp_dir_fake, ignore_errors=True)
1554
+
1555
+ return fid
1556
+
1557
+ def compute_metrics(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float]:
1558
+ """Computes MSE, PSNR, and SSIM for evaluating image quality.
1559
+
1560
+ Parameters
1561
+ ----------
1562
+ x : torch.Tensor
1563
+ Ground truth images, shape (batch_size, channels, height, width).
1564
+ x_hat : torch.Tensor
1565
+ Generated images, same shape as `x`.
1566
+
1567
+ Returns
1568
+ -------
1569
+ mse : float
1570
+ Mean squared error.
1571
+ psnr : float
1572
+ Peak signal-to-noise ratio.
1573
+ ssim : float
1574
+ Structural similarity index (mean over batch).
1575
+ """
1576
+ if x.shape != x_hat.shape:
1577
+ raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")
1578
+
1579
+ mse = F.mse_loss(x_hat, x)
1580
+ psnr = -10 * torch.log10(mse)
1581
+ c1, c2 = (0.01 * 2) ** 2, (0.03 * 2) ** 2 # Adjusted for [-1, 1] range
1582
+ eps = 1e-8
1583
+ mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
1584
+ mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
1585
+ mu_xy = mu_x * mu_y
1586
+ sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
1587
+ sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
1588
+ sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
1589
+ ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
1590
+ (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
1591
+ )
1592
+
1593
+ return mse.item(), psnr.item(), ssim.mean().item()
1594
+
1595
def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """Computes LPIPS using a pre-trained VGG network.

    Both batches are rescaled from [-1, 1] to [0, 1], clamped, moved to
    ``self.device`` and, if single-channel, replicated to 3 channels before
    being passed to ``self.lpips_model``.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
    x_hat : torch.Tensor
        Generated images, same shape as `x`.

    Returns
    -------
    lpips : float
        Mean LPIPS score over the batch.

    Raises
    ------
    RuntimeError
        If ``self.lpips_model`` is None (LPIPS was not enabled).
    ValueError
        If `x` and `x_hat` do not have the same shape.
    """
    if self.lpips_model is None:
        raise RuntimeError("LPIPS model not initialized; set lpips=True in __init__")
    if x.shape != x_hat.shape:
        raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

    def _prepare(img: torch.Tensor) -> torch.Tensor:
        # Map [-1, 1] -> [0, 1], clamp out-of-range values, move to the metric device.
        img = ((img + 1) / 2).clamp(0, 1).to(self.device)
        # Convert grayscale to RGB by repeating the single channel 3 times.
        if img.shape[1] == 1:
            img = img.repeat(1, 3, 1, 1)
        return img

    # NOTE(review): inputs are rescaled to [0, 1] before the model call. If
    # self.lpips_model is an `lpips.LPIPS` instance, that model expects [-1, 1]
    # inputs unless called with normalize=True -- confirm against __init__.
    #
    # The score is returned as a plain Python float, so no autograd graph is
    # needed; disabling it avoids keeping backbone activations alive.
    with torch.no_grad():
        return self.lpips_model(_prepare(x), _prepare(x_hat)).mean().item()
1630
+
1631
def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float, float, float]:
    """Evaluate every enabled metric on a pair of image batches.

    The enabled metrics run in a fixed order:
    1. pixel metrics (MSE/PSNR/SSIM) when ``self.metrics`` is truthy;
    2. FID when ``self.fid`` is truthy;
    3. LPIPS when ``self.lpips`` is truthy.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
    x_hat : torch.Tensor
        Generated images, same shape as `x`.

    Returns
    -------
    fid : float
        Result of ``compute_fid``, or ``float('inf')`` when FID is disabled.
    mse, psnr, ssim : float or None
        Results of ``compute_metrics``, or None when ``self.metrics`` is falsy.
    lpips_score : float or None
        Result of ``compute_lpips``, or None when ``self.lpips`` is falsy.

    Note that, despite the return annotation, all entries except `fid`
    may be None.
    """
    mse = psnr = ssim = None
    if self.metrics:
        mse, psnr, ssim = self.compute_metrics(x, x_hat)

    # Unlike the other metrics, a disabled FID is reported as +inf, not None.
    fid = self.compute_fid(x, x_hat) if self.fid else float('inf')
    lpips_score = self.compute_lpips(x, x_hat) if self.lpips else None

    return fid, mse, psnr, ssim, lpips_score
1666
+
1667
+
1668
+
1669
+ """
1670
+ # Ensure GPU usage
1671
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1672
+ x = torch.randn(5, 3, 50, 50).to(device)
1673
+ t = torch.rand(5).to(device)
1674
+
1675
+ np_ = NoisePredictor(
1676
+ in_channels=3,
1677
+ down_channels=[16, 32],
1678
+ mid_channels=[32, 32],
1679
+ up_channels=[32, 16],
1680
+ down_sampling=[True, True],
1681
+ time_embed_dim=20,
1682
+ y_embed_dim=20,
1683
+ num_down_blocks=2,
1684
+ num_mid_blocks=2,
1685
+ num_up_blocks=2,
1686
+ down_sampling_factor=2
1687
+ ).to(device)
1688
+
1689
+ # Uncompiled
1690
+ times = []
1691
+ for i in range(50):
1692
+ st = time.time()
1693
+ np_(x, t)
1694
+ torch.cuda.synchronize() # Ensure GPU operations complete
1695
+ ft = time.time()
1696
+ times.append(ft - st)
1697
+ print(f"Uncompiled iter: {i+1} finished in {ft - st:.4f} s")
1698
+ print(f"Uncompiled average: {sum(times)/len(times):.4f} s")
1699
+
1700
+ # Compiled with warm-up
1701
+ cm = torch.compile(np_, mode="reduce-overhead")
1702
+ for _ in range(5): # Warm-up
1703
+ cm(x, t)
1704
+ torch.cuda.synchronize()
1705
+
1706
+ times = []
1707
+ for i in range(50):
1708
+ st = time.time()
1709
+ cm(x, t)
1710
+ torch.cuda.synchronize()
1711
+ ft = time.time()
1712
+ times.append(ft - st)
1713
+ print(f"Compiled iter: {i+1} finished in {ft - st:.4f} s")
1714
+ print(f"Compiled average: {sum(times)/len(times):.4f} s")
1715
+
1716
+
1717
+ import transformers
1718
+ import time
1719
+
1720
+ # Set up device and inputs
1721
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1722
+ x = torch.randn(5, 3, 50, 50).to(device) # Image input (not used in TextEncoder)
1723
+ t = torch.rand(5).to(device) # Time embedding (not used)
1724
+ y = ["do", "not to do", "correctly", "wrong", "it is true"] # Text inputs
1725
+
1726
+ # Initialize tokenizer and TextEncoder
1727
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
1728
+ te = TextEncoder(
1729
+ use_pretrained_model=True,
1730
+ model_name="bert-base-uncased",
1731
+ num_layers=2,
1732
+ vocabulary_size=30522,
1733
+ input_dimension=768,
1734
+ output_dimension=768,
1735
+ dropout_rate=0.2,
1736
+ num_heads=6,
1737
+ context_length=77
1738
+ ).to(device) # Use FP16 for better performance
1739
+
1740
+ # Tokenize inputs
1741
+ y_list = y # Already a list of strings
1742
+ y_encoded = tokenizer(
1743
+ y_list,
1744
+ padding="max_length",
1745
+ truncation=True,
1746
+ max_length=77,
1747
+ return_tensors="pt"
1748
+ ).to(device) # Convert to FP16
1749
+ input_ids = y_encoded["input_ids"] # Shape: [5, 77]
1750
+ attention_mask = y_encoded["attention_mask"] # Shape: [5, 77]
1751
+
1752
+ # Uncompiled TextEncoder
1753
+ times = []
1754
+ for i in range(50):
1755
+ st = time.time()
1756
+ with torch.no_grad(): # Disable gradients
1757
+ y_encoded = te(input_ids, attention_mask)
1758
+ torch.cuda.synchronize()
1759
+ ft = time.time()
1760
+ times.append(ft - st)
1761
+ print(f"Uncompiled iter: {i+1} finished in {ft - st:.4f} s")
1762
+ print(f"Uncompiled average: {sum(times)/len(times):.4f} s")
1763
+
1764
+ # Compiled TextEncoder with warm-up
1765
+ cm = torch.compile(te, mode="reduce-overhead")
1766
+ for _ in range(5): # Warm-up runs
1767
+ with torch.no_grad():
1768
+ cm(input_ids, attention_mask)
1769
+ torch.cuda.synchronize()
1770
+
1771
+ times = []
1772
+ for i in range(50):
1773
+ st = time.time()
1774
+ with torch.no_grad(): # Disable gradients
1775
+ y_encoded = cm(input_ids, attention_mask)
1776
+ torch.cuda.synchronize()
1777
+ ft = time.time()
1778
+ times.append(ft - st)
1779
+ print(f"Compiled iter: {i+1} finished in {ft - st:.4f} s")
1780
+ print(f"Compiled average: {sum(times)/len(times):.4f} s")
1781
+
1782
+ # Verify output shape
1783
+ print(f"Output shape: {y_encoded.shape}")
1784
+ """
1785
+
1786
+
1787
+
1788
+
1789
+
1790
+
1791
+
1792
+
1793
+