TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/val_metrics.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
|
5
|
+
from pytorch_fid import fid_score
|
|
6
|
+
from torchvision.utils import save_image
|
|
7
|
+
from transformers import BertModel
|
|
8
|
+
import os
|
|
9
|
+
import lpips
|
|
10
|
+
import math
|
|
11
|
+
import shutil
|
|
12
|
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
|
13
|
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
|
14
|
+
from torchvision.utils import save_image
|
|
15
|
+
from typing import Optional, Tuple, List
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Metrics:
    """Image quality metrics for evaluating diffusion models.

    Offers Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR),
    Structural Similarity Index (SSIM), Fréchet Inception Distance (FID),
    and Learned Perceptual Image Patch Similarity (LPIPS) between generated
    samples and their ground-truth counterparts.

    Parameters
    ----------
    device : str, optional
        Computation device, e.g. 'cuda' or 'cpu' (default: 'cuda').
    fid : bool, optional
        When True, FID is computed in `forward` (default: True).
    metrics : bool, optional
        When True, MSE, PSNR, and SSIM are computed in `forward`
        (default: False).
    lpips_ : bool, optional
        When True, LPIPS with a VGG backbone is computed in `forward`
        (default: False).
    """

    def __init__(
        self,
        device: str = "cuda",
        fid: bool = True,
        metrics: bool = False,
        lpips_: bool = False
    ) -> None:
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_
        # The VGG-based LPIPS network is only instantiated when requested,
        # since loading its pretrained weights is comparatively expensive.
        if self.lpips:
            self.lpips_model = LearnedPerceptualImagePatchSimilarity(
                net_type='vgg',
                normalize=True  # handles the [0, 1] -> [-1, 1] conversion internally
            ).to(device)
        else:
            self.lpips_model = None
        # Scratch directories used by `compute_fid` for on-disk image dumps.
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"
|
|
59
|
+
|
|
60
|
+
def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
    """Computes the Fréchet Inception Distance (FID) between real and generated images.

    Images are rescaled from [-1, 1] to [0, 1], written to temporary
    directories as PNG files, and scored with Inception V3 via
    ``pytorch_fid``; the directories are removed afterward.

    Parameters
    ----------
    real_images : torch.Tensor
        Real images, shape (batch_size, channels, height, width), in [-1, 1].
    fake_images : torch.Tensor
        Generated images, same shape, in [-1, 1].

    Returns
    -------
    fid (float) - FID score, or `float('inf')` if computation fails.

    Raises
    ------
    ValueError
        If the two batches have different shapes.

    **Notes**

    - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
    - Each call creates its own unique temporary directories, so concurrent
      evaluations cannot clobber (or delete) each other's files.
    """
    import tempfile  # local stdlib import: only this method needs it

    if real_images.shape != fake_images.shape:
        raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")

    # Map [-1, 1] -> [0, 1] and move to CPU for saving to disk.
    real_images = ((real_images + 1) / 2).clamp(0, 1).cpu()
    fake_images = ((fake_images + 1) / 2).clamp(0, 1).cpu()

    # mkdtemp gives each call private directories; the previous fixed names
    # ("temp_real"/"temp_fake") raced between concurrent runs, and the
    # cleanup rmtree could delete a pre-existing directory of the same name.
    dir_real = tempfile.mkdtemp(prefix="temp_real_")
    dir_fake = tempfile.mkdtemp(prefix="temp_fake_")

    try:
        for i, (real, fake) in enumerate(zip(real_images, fake_images)):
            save_image(real, f"{dir_real}/{i}.png")
            save_image(fake, f"{dir_fake}/{i}.png")

        fid = fid_score.calculate_fid_given_paths(
            paths=[dir_real, dir_fake],
            batch_size=50,
            device=self.device,
            dims=2048
        )
    except Exception as e:
        print(f"Error computing FID: {e}")
        fid = float('inf')
    finally:
        shutil.rmtree(dir_real, ignore_errors=True)
        shutil.rmtree(dir_fake, ignore_errors=True)

    return fid
|
|
112
|
+
|
|
113
|
+
def compute_metrics(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float]:
    """Computes MSE, PSNR, and SSIM for evaluating image quality.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth images, shape (batch_size, channels, height, width),
        expected in [-1, 1] (data range 2, matching the SSIM constants below).
    x_hat : torch.Tensor
        Generated images, same shape as `x`.

    Returns
    -------
    mse : float
        Mean squared error.
    psnr : float
        Peak signal-to-noise ratio (``inf`` when the images are identical).
    ssim : float
        Structural similarity index (mean over batch).

    Raises
    ------
    ValueError
        If `x` and `x_hat` have different shapes.
    """
    if x.shape != x_hat.shape:
        raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

    mse = F.mse_loss(x_hat, x)
    # PSNR = 10*log10(L^2 / MSE) with data range L = 2 for [-1, 1] inputs.
    # (The previous `-10*log10(mse)` assumed L = 1 and under-reported PSNR
    # by 10*log10(4) ~= 6.02 dB for this range.)
    psnr = 10 * torch.log10(4.0 / mse)
    c1, c2 = (0.01 * 2) ** 2, (0.03 * 2) ** 2  # Adjusted for [-1, 1] range
    eps = 1e-8
    # Local statistics via 3x3 mean filters — a lightweight SSIM variant
    # (uniform window instead of the Gaussian of the original paper).
    mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
    mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
    mu_xy = mu_x * mu_y
    sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
    sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
    sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
    ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
        (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
    )

    return mse.item(), psnr.item(), ssim.mean().item()
|
|
150
|
+
|
|
151
|
+
def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """Computes LPIPS using a pre-trained VGG network.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
    x_hat : torch.Tensor
        Generated images, same shape as `x`.

    Returns
    -------
    lpips (float) - Mean LPIPS score over the batch.

    Raises
    ------
    RuntimeError
        If the LPIPS network was not created (``lpips_=False`` in ``__init__``).
    ValueError
        If `x` and `x_hat` have different shapes.
    """
    if self.lpips_model is None:
        raise RuntimeError("LPIPS model not initialized; set lpips=True in __init__")
    if x.shape != x_hat.shape:
        raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

    def _prepare(img: torch.Tensor) -> torch.Tensor:
        # Map [-1, 1] -> [0, 1] (the metric was built with normalize=True),
        # clamp stragglers, move to the compute device, and expand
        # single-channel inputs to the three channels VGG expects.
        img = ((img + 1) / 2).clamp(0, 1).to(self.device)
        if img.shape[1] == 1:
            img = img.repeat(1, 3, 1, 1)
        return img

    return self.lpips_model(_prepare(x), _prepare(x_hat)).mean().item()
|
|
186
|
+
|
|
187
|
+
def forward(self, x: torch.Tensor, x_hat: torch.Tensor) -> Tuple[float, float, float, float, float]:
    """Computes the metrics enabled at construction time.

    Parameters
    ----------
    x : torch.Tensor
        Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
    x_hat : torch.Tensor
        Generated images, same shape as `x`.

    Returns
    -------
    fid : float, or `float('inf')` if not computed
        Mean FID score.
    mse : float, or None if not computed
        Mean MSE
    psnr : float, or None if not computed
        Mean PSNR
    ssim : float, or None if not computed
        Mean SSIM
    lpips_score : float, or None if not computed
        Mean LPIPS score
    """
    # Disabled metrics fall back to their sentinel values
    # (inf for FID, None for everything else).
    mse, psnr, ssim = self.compute_metrics(x, x_hat) if self.metrics else (None, None, None)
    fid = self.compute_fid(x, x_hat) if self.fid else float('inf')
    lpips_score = self.compute_lpips(x, x_hat) if self.lpips else None

    return fid, mse, psnr, ssim, lpips_score
|