TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/val_metrics.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.attention import sdpa_kernel, SDPBackend
5
+ from pytorch_fid import fid_score
6
+ from torchvision.utils import save_image
7
+ from transformers import BertModel
8
+ import os
9
+ import lpips
10
+ import math
11
+ import shutil
12
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
13
+ from torchmetrics.image.fid import FrechetInceptionDistance
14
+ from torchvision.utils import save_image
15
+ from typing import Optional, Tuple, List
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
class Metrics:
    """Computes image quality metrics for evaluating diffusion models.

    Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
    Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
    Image Patch Similarity (LPIPS) for comparing generated and ground truth images.

    Parameters
    ----------
    device : str, optional
        Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
    fid : bool, optional
        If True, compute FID score (default: True).
    metrics : bool, optional
        If True, compute MSE, PSNR, and SSIM (default: False).
    lpips_ : bool, optional
        If True, compute LPIPS using VGG backbone (default: False).
        The trailing underscore avoids shadowing the module-level ``lpips`` import.
    """

    def __init__(
        self,
        device: str = "cuda",
        fid: bool = True,
        metrics: bool = False,
        lpips_: bool = False
    ) -> None:
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_
        # Only build the (VGG-backed) LPIPS model when it was requested, so
        # metrics-only usage never downloads/loads VGG weights.
        self.lpips_model = LearnedPerceptualImagePatchSimilarity(
            net_type='vgg',
            normalize=True  # This handles [0,1] -> [-1,1] conversion
        ).to(device) if self.lpips else None
        # Scratch directories used by compute_fid; NOTE(review): the fixed names
        # mean two concurrent runs in the same working directory would collide.
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"

    def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
        """Computes the Fréchet Inception Distance (FID) between real and generated images.

        Saves images to temporary directories and uses Inception V3 to compute FID,
        cleaning up directories afterward.

        Parameters
        ----------
        real_images : torch.Tensor
            Real images, shape (batch_size, channels, height, width), in [-1, 1].
        fake_images : torch.Tensor
            Generated images, same shape, in [-1, 1].

        Returns
        -------
        fid (float) - FID score, or `float('inf')` if computation fails.

        **Notes**

        - Images are normalized to [0, 1] and saved as PNG files for FID computation.
        - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
        - Any failure (e.g. too few images for the covariance estimate) is caught,
          printed, and reported as `float('inf')` rather than raised.
        """
        if real_images.shape != fake_images.shape:
            raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")

        # Map [-1, 1] -> [0, 1] and move to CPU so save_image can write PNGs.
        real_images = (real_images + 1) / 2
        fake_images = (fake_images + 1) / 2
        real_images = real_images.clamp(0, 1).cpu()
        fake_images = fake_images.clamp(0, 1).cpu()

        os.makedirs(self.temp_dir_real, exist_ok=True)
        os.makedirs(self.temp_dir_fake, exist_ok=True)

        try:
            for i, (real, fake) in enumerate(zip(real_images, fake_images)):
                save_image(real, f"{self.temp_dir_real}/{i}.png")
                save_image(fake, f"{self.temp_dir_fake}/{i}.png")

            fid = fid_score.calculate_fid_given_paths(
                paths=[self.temp_dir_real, self.temp_dir_fake],
                batch_size=50,
                device=self.device,
                dims=2048
            )
        except Exception as e:
            # Best-effort: FID failures should not abort an evaluation run.
            print(f"Error computing FID: {e}")
            fid = float('inf')
        finally:
            # Always remove the scratch PNG directories, even on failure.
            shutil.rmtree(self.temp_dir_real, ignore_errors=True)
            shutil.rmtree(self.temp_dir_fake, ignore_errors=True)

        return fid

    def compute_metrics(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float]:
        """Computes MSE, PSNR, and SSIM for evaluating image quality.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width).
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        mse : float
            Mean squared error.
        psnr : float
            Peak signal-to-noise ratio.
        ssim : float
            Structural similarity index (mean over batch).
        """
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        mse = F.mse_loss(x_hat, x)
        # NOTE(review): this PSNR assumes a peak value of 1. The SSIM constants
        # below are adjusted for a [-1, 1] range (peak 2); if inputs really are
        # in [-1, 1], PSNR is understated by 20*log10(2) ~= 6.02 dB — confirm
        # the intended input range before comparing against published numbers.
        # Also, identical inputs (mse == 0) yield psnr = +inf.
        psnr = -10 * torch.log10(mse)
        c1, c2 = (0.01 * 2) ** 2, (0.03 * 2) ** 2  # Adjusted for [-1, 1] range
        eps = 1e-8  # guards against a zero denominator on flat regions
        # Local (3x3, zero-padded) means and (co)variances for per-pixel SSIM.
        mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
        mu_xy = mu_x * mu_y
        sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
        sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
        sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
        ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
            (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
        )

        return mse.item(), psnr.item(), ssim.mean().item()

    def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
        """Computes LPIPS using a pre-trained VGG network.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        lpips (float) - Mean LPIPS score over the batch.

        Raises
        ------
        RuntimeError
            If the LPIPS model was not enabled at construction time.
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if self.lpips_model is None:
            # The constructor keyword is lpips_ (trailing underscore).
            raise RuntimeError("LPIPS model not initialized; set lpips_=True in __init__")
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        # Normalize inputs to [0, 1] range; the model (normalize=True) maps
        # [0, 1] back to the [-1, 1] range VGG-LPIPS expects internally.
        x = (x + 1) / 2  # Convert from [-1, 1] to [0, 1]
        x_hat = (x_hat + 1) / 2
        x = x.clamp(0, 1)  # Ensure values are in [0, 1]
        x_hat = x_hat.clamp(0, 1)

        x = x.to(self.device)
        x_hat = x_hat.to(self.device)

        # Convert grayscale to RGB if needed (VGG expects 3 channels).
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)  # Repeat grayscale channel 3 times
        if x_hat.shape[1] == 1:
            x_hat = x_hat.repeat(1, 3, 1, 1)

        return self.lpips_model(x, x_hat).mean().item()

    def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float, float, float]:
        """Computes specified metrics for ground truth and generated images.

        Which metrics are computed is controlled by the `fid`, `metrics`, and
        `lpips_` flags passed to `__init__`; disabled metrics are returned as
        `float('inf')` (FID) or `None` (all others).

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        fid : float, or `float('inf')` if not computed
            Mean FID score.
        mse : float, or None if not computed
            Mean MSE
        psnr : float, or None if not computed
            Mean PSNR
        ssim : float, or None if not computed
            Mean SSIM
        lpips_score : float, or None if not computed
            Mean LPIPS score
        """
        fid = float('inf')
        mse, psnr, ssim = None, None, None
        lpips_score = None

        if self.metrics:
            mse, psnr, ssim = self.compute_metrics(x, x_hat)
        if self.fid:
            fid = self.compute_fid(x, x_hat)
        if self.lpips:
            lpips_score = self.compute_lpips(x, x_hat)

        return fid, mse, psnr, ssim, lpips_score