TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/forward_idm.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
+
7
class ForwardSDE(nn.Module):
    """One discretised forward (noising) step of a score-based SDE.

    Args:
        hyper_params: schedule container. Must expose ``dt``; ``sigmas`` for
            ``"ve"``; ``betas`` for ``"vp"``, ``"sub-vp"`` and ``"ode"``; and
            ``cum_betas`` for ``"sub-vp"``. Each schedule tensor is indexed
            with ``time_steps``.
        method: one of ``"ve"``, ``"vp"``, ``"sub-vp"`` or ``"ode"``.
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    @staticmethod
    def _per_sample(coeff, ref):
        """Reshape a per-sample ``(batch,)`` coefficient to ``(batch, 1, ..., 1)`` matching ``ref``'s rank."""
        return coeff.view(-1, *([1] * (ref.dim() - 1)))

    def forward(self, x0, noise, time_steps):
        """Apply one Euler-Maruyama-style step of the selected SDE to ``x0``.

        Args:
            x0: batch to perturb; dimension 0 is the batch dimension.
            noise: noise tensor with the same shape as ``x0``.
            time_steps: integer tensor holding one schedule index per sample.

        Returns:
            The perturbed batch, same shape as ``x0``.

        Raises:
            ValueError: if ``self.method`` is not a supported method.
        """
        hp = self.hyper_params
        dt = hp.dt

        if self.method == "ve":
            sigmas = hp.sigmas
            sigma_t = sigmas[time_steps]
            # Decide per sample whether a previous sigma exists, so a t=0 sample
            # does not zero sigma_{t-1} for every other sample in the batch.
            has_prev = (time_steps > 0).to(sigma_t.device)
            prev_idx = (time_steps - 1).clamp(min=0)
            sigma_t_prev = torch.where(has_prev, sigmas[prev_idx], torch.zeros_like(sigma_t))
            sigma_diff = torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_t_prev ** 2, min=0))
            return x0 + noise * self._per_sample(sigma_diff, x0)

        if self.method not in ("vp", "sub-vp", "ode"):
            raise ValueError(f"Unknown method: {self.method}")

        betas = self._per_sample(hp.betas[time_steps], x0)
        drift = -0.5 * betas * x0 * dt

        if self.method == "vp":
            return x0 + drift + torch.sqrt(betas * dt) * noise

        if self.method == "sub-vp":
            cum_betas = self._per_sample(hp.cum_betas[time_steps], x0)
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
            return x0 + drift + diffusion

        # "ode": deterministic drift only, no stochastic term.
        return x0 + drift
45
+
46
+
47
+
48
class ForwardDDPM(nn.Module):
    """Forward diffusion process of DDPM: closed-form sample of x_t given x_0.

    ``hyper_params`` must expose ``num_steps`` and ``trainable_beta``. When
    ``trainable_beta`` is true it must also expose ``betas`` and
    ``compute_schedule``. Otherwise it must expose the precomputed
    ``sqrt_alpha_bars`` / ``sqrt_one_minus_alpha_bars`` tables.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Return ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            x0: clean batch; dimension 0 is the batch dimension (any rank).
            noise: noise tensor with the same shape as ``x0``.
            time_steps: integer tensor of per-sample indices in ``[0, num_steps)``.

        Returns:
            The noised batch ``x_t``, same shape as ``x0``.

        Raises:
            ValueError: if any entry of ``time_steps`` is out of range.
        """
        num_steps = self.hyper_params.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute from the live betas so gradients reach the trainable schedule.
            _, _, _, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
        else:
            sqrt_alpha_bars = self.hyper_params.sqrt_alpha_bars
            sqrt_one_minus_alpha_bars = self.hyper_params.sqrt_one_minus_alpha_bars

        # (batch, 1, ..., 1) so the coefficients broadcast over every non-batch dim.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = sqrt_alpha_bars[time_steps].to(x0.device).view(shape)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bars[time_steps].to(x0.device).view(shape)

        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
73
+
74
+
75
+
76
class ForwardDDIM(nn.Module):
    """Forward diffusion process of DDIM: closed-form sample of x_t given x_0.

    ``hyper_params`` must expose ``num_steps`` and ``trainable_beta``. When
    ``trainable_beta`` is true it must also expose ``betas`` and
    ``compute_schedule``. Otherwise it must expose the precomputed
    ``sqrt_alpha_cumprod`` / ``sqrt_one_minus_alpha_cumprod`` tables.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Return ``sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise``.

        Args:
            x0: clean batch; dimension 0 is the batch dimension (any rank).
            noise: noise tensor with the same shape as ``x0``.
            time_steps: integer tensor of per-sample indices in ``[0, num_steps)``.

        Returns:
            The noised batch ``x_t``, same shape as ``x0``.

        Raises:
            ValueError: if any entry of ``time_steps`` is out of range.
        """
        num_steps = self.hyper_params.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute from the live betas so gradients reach the trainable schedule.
            _, _, _, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
        else:
            sqrt_alpha_cumprod = self.hyper_params.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.hyper_params.sqrt_one_minus_alpha_cumprod

        # (batch, 1, ..., 1) so the coefficients broadcast over every non-batch dim.
        shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod[time_steps].to(x0.device).view(shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod[time_steps].to(x0.device).view(shape)

        return sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
ldm/hyper_param.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class HyperParamsDDPM(nn.Module):
    """Hyperparameters for the DDPM noise schedule with flexible beta computation.

    Args:
        num_steps: number of diffusion steps; must be positive.
        beta_start: smallest beta; must satisfy ``0 < beta_start < beta_end < 1``.
        beta_end: largest beta.
        trainable_beta: if True, ``betas`` is an ``nn.Parameter``; otherwise a buffer.
        beta_method: schedule shape, one of "linear", "sigmoid", "quadratic",
            "constant", "inverse_time".

    Raises:
        ValueError: on an invalid beta range, non-positive ``num_steps`` or an
            unknown ``beta_method``.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear"):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        # validate inputs
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)  # trainable parameter
        else:
            self.register_buffer('betas', betas_init)  # fixed buffer

        # Derived buffers are built from the plain ``betas_init`` tensor rather than
        # the Parameter. A buffer computed as ``1 - parameter`` would hold an autograd
        # graph, which breaks ``copy.deepcopy`` of the module. With trainable betas
        # these buffers reflect only the initial schedule; use
        # ``compute_schedule(self.betas)`` for the live values.
        _, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = self.compute_schedule(betas_init)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('sqrt_alpha_bars', sqrt_alpha_bars)
        self.register_buffer('sqrt_one_minus_alpha_bars', sqrt_one_minus_alpha_bars)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """
        Computes the beta schedule based on the selected method.
        Args:
            beta_range: Tuple of (min_beta, max_beta) values
            num_steps: Number of diffusion steps
            method: Method for beta schedule ("linear", "sigmoid", "quadratic", "constant", "inverse_time")
        Returns:
            Tensor of beta values, shape [num_steps], clamped to beta_range
        Raises:
            ValueError: if ``method`` is unknown
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min**0.5, beta_max**0.5, num_steps)
            beta = x**2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                # scale to beta_range
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # num_steps == 1: a single point has no range to rescale (0/0 would
                # give NaN), so start at beta_min like the linear schedule does.
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    @staticmethod
    def compute_schedule(betas):
        """Compute (betas, alphas, alpha_bars, sqrt(alpha_bars), sqrt(1 - alpha_bars)) from ``betas``."""
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op for fixed betas."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
79
+
80
+
81
+
82
+
83
+
84
class HyperParamsDDIM(nn.Module):
    """Hyperparameters for the DDIM noise schedule with flexible beta computation.

    Args:
        eta: DDIM stochasticity; ``None`` is stored as 0.
        num_steps: number of diffusion steps; must be positive.
        tau_num_steps: length of the strided sub-sequence ``tau_indices``.
        beta_start: smallest beta; must satisfy ``0 < beta_start < beta_end < 1``.
        beta_end: largest beta.
        trainable_beta: if True, ``betas`` is an ``nn.Parameter``; otherwise a buffer.
        beta_method: schedule shape, one of "linear", "sigmoid", "quadratic",
            "constant", "inverse_time".

    Raises:
        ValueError: on an invalid beta range, non-positive ``num_steps`` or an
            unknown ``beta_method``.
    """

    def __init__(self, eta=None, num_steps=1000, tau_num_steps=100, beta_start=1e-4, beta_end=0.02,
                 trainable_beta=False, beta_method="linear"):
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)

        # Derived buffers are built from the plain ``betas_init`` tensor rather than
        # the Parameter. A buffer computed as ``1 - parameter`` would hold an autograd
        # graph, which breaks ``copy.deepcopy`` of the module. With trainable betas
        # these buffers reflect only the initial schedule; ``get_tau_schedule`` and
        # ``compute_schedule(self.betas)`` give the live values.
        _, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule(betas_init)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_cumprod', alpha_cumprod)
        self.register_buffer('sqrt_alpha_cumprod', sqrt_alpha_cumprod)
        self.register_buffer('sqrt_one_minus_alpha_cumprod', sqrt_one_minus_alpha_cumprod)

        # Evenly spaced subset of the full schedule used by the DDIM sampler.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute a beta schedule of length ``num_steps`` clamped to ``beta_range``.

        Raises:
            ValueError: if ``method`` is unknown.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                # scale to beta_range
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # num_steps == 1: nothing to rescale (0/0 would give NaN).
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self):
        """Return (betas, alphas, alpha_cumprod, sqrt(alpha_cumprod), sqrt(1 - alpha_cumprod)) at ``tau_indices``.

        With trainable betas the values are recomputed from the live parameter so
        gradients flow; otherwise the precomputed buffers are used.
        """
        if self.trainable_beta:
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule(self.betas)
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        idx = self.tau_indices
        return (
            betas[idx],
            alphas[idx],
            alpha_cumprod[idx],
            sqrt_alpha_cumprod[idx],
            sqrt_one_minus_alpha_cumprod[idx],
        )

    @staticmethod
    def compute_schedule(betas):
        """Compute (betas, alphas, alpha_cumprod, sqrt(alpha_cumprod), sqrt(1 - alpha_cumprod)) from ``betas``."""
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_cumprod, torch.sqrt(alpha_cumprod), torch.sqrt(1 - alpha_cumprod)

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op for fixed betas."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
169
+
170
+
171
+
172
+
173
class HyperParamsSDE(nn.Module):
    """Schedules for VE / VP / sub-VP score-based SDEs.

    Args:
        num_steps: number of discretisation steps; must be positive.
        beta_start, beta_end: beta range for VP-type SDEs (``0 < start < end``).
        trainable_beta: if True, ``betas`` is an ``nn.Parameter``; otherwise a buffer.
        beta_method: beta schedule shape, one of "linear", "sigmoid",
            "quadratic", "constant", "inverse_time".
        sigma_start, sigma_end: geometric sigma range for the VE SDE (``0 < start < end``).
        start, end: continuous time interval; ``dt = (end - start) / num_steps``.

    Raises:
        ValueError: on invalid ranges, non-positive ``num_steps`` or an unknown
            ``beta_method``.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear",
                 sigma_start=1e-3, sigma_end=10.0, start=0.0, end=1.0):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)
        time = torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32)
        self.dt = (self.end - self.start) / self.num_steps

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
        # NOTE(review): cum_betas is built from the initial schedule. With
        # trainable betas it is not refreshed as they train — confirm whether
        # callers expect that.
        self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
        self.register_buffer("sigmas", self.sigma_start * (self.sigma_end / self.sigma_start) ** time)
        # Non-persistent buffer: it follows .to(device) with the module but is kept
        # out of state_dict, so checkpoints without a "time" key still load strictly.
        self.register_buffer('time', time, persistent=False)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute a beta schedule of length ``num_steps`` clamped to ``beta_range``.

        Raises:
            ValueError: if ``method`` is unknown.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # num_steps == 1: nothing to rescale (0/0 would give NaN).
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op for fixed betas."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)

    def get_variance(self, time_steps, method):
        """Marginal variance of x_t given x_0 at ``time_steps`` for the chosen SDE.

        With ``B = cum_betas[t]`` (the discretised integral of beta):
          * ``"ve"``:     ``sigma_t ** 2``
          * ``"vp"``:     ``1 - exp(-B)``
          * ``"sub-vp"``: ``(1 - exp(-B)) ** 2``. This is always <= the VP
            variance; the previous ``1 - exp(-2B)`` was larger than VP.

        Raises:
            ValueError: if ``method`` is unknown.
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        if method == "vp":
            return 1 - torch.exp(-self.cum_betas[time_steps])
        if method == "sub-vp":
            return (1 - torch.exp(-self.cum_betas[time_steps])) ** 2
        raise ValueError(f"Unknown method: {method}")
ldm/metrics.py ADDED
@@ -0,0 +1,206 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import lpips
4
+ from pytorch_fid import fid_score
5
+ import shutil
6
+ from torchvision.utils import save_image
7
+ import os
8
+
9
+
10
class Metrics:
    """Computes image quality metrics for evaluating diffusion models.

    Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
    Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
    Image Patch Similarity (LPIPS) for comparing generated and ground truth images.

    Parameters
    ----------
    device : str, optional
        Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
    fid : bool, optional
        If True, compute FID score (default: True).
    metrics : bool, optional
        If True, compute MSE, PSNR, and SSIM (default: False).
    lpips : bool, optional
        If True, compute LPIPS using VGG backbone (default: False).

    Attributes
    ----------
    device : str
        Computation device.
    fid : bool
        Flag for FID computation.
    metrics : bool
        Flag for MSE, PSNR, SSIM computation.
    lpips : bool
        Flag for LPIPS computation.
    lpips_model : lpips.LPIPS or None
        LPIPS model (VGG backbone) if `lpips=True`; otherwise, None.
    temp_dir_real : str
        Base path for temporary real-image directories used by FID. Each call
        creates a fresh, uniquely named directory ``<basename>_XXXX`` next to
        this path and removes it afterwards.
    temp_dir_fake : str
        Same as `temp_dir_real`, for fake (generated) images.
    """

    def __init__(self, device="cuda", fid=True, metrics=False, lpips=False):
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips
        self.lpips_model = None
        if lpips:
            # The boolean ``lpips`` parameter shadows the ``lpips`` package inside
            # this method, so import the package under a different local name.
            import lpips as lpips_lib
            self.lpips_model = lpips_lib.LPIPS(net='vgg').to(device)
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"

    def __call__(self, x, x_hat):
        """Alias for :meth:`forward`."""
        return self.forward(x, x_hat)

    @staticmethod
    def _make_temp_dir(base):
        """Create and return a new, uniquely named directory beside ``base``.

        A fresh directory is never a pre-existing user folder, so the later
        ``rmtree`` cannot delete user data, and concurrent calls do not collide.
        """
        import tempfile
        parent = os.path.dirname(os.path.abspath(base))
        return tempfile.mkdtemp(prefix=f"{os.path.basename(base)}_", dir=parent)

    def compute_fid(self, real_images, fake_images):
        """Computes the Fréchet Inception Distance (FID) between real and generated images.

        Saves images to fresh temporary directories and uses Inception V3 to compute FID,
        cleaning up those directories afterward.

        Parameters
        ----------
        real_images : torch.Tensor
            Real images, shape (batch_size, channels, height, width), in [-1, 1].
        fake_images : torch.Tensor
            Generated images, same shape, in [-1, 1].

        Returns
        -------
        float
            FID score, or `float('inf')` if computation fails.

        Raises
        ------
        ValueError
            If the two batches have different shapes.

        Notes
        -----
        - Images are normalized to [0, 1] and saved as PNG files for FID computation.
        - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
        """
        if real_images.shape != fake_images.shape:
            raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")

        real_images = ((real_images.detach() + 1) / 2).clamp(0, 1).cpu()
        fake_images = ((fake_images.detach() + 1) / 2).clamp(0, 1).cpu()

        real_dir = self._make_temp_dir(self.temp_dir_real)
        fake_dir = self._make_temp_dir(self.temp_dir_fake)

        try:
            for i, (real, fake) in enumerate(zip(real_images, fake_images)):
                save_image(real, os.path.join(real_dir, f"{i}.png"))
                save_image(fake, os.path.join(fake_dir, f"{i}.png"))

            fid = fid_score.calculate_fid_given_paths(
                paths=[real_dir, fake_dir],
                batch_size=50,
                device=self.device,
                dims=2048
            )
        except Exception as e:
            # Best-effort metric: report and fall back to inf instead of aborting evaluation.
            print(f"Error computing FID: {e}")
            fid = float('inf')
        finally:
            shutil.rmtree(real_dir, ignore_errors=True)
            shutil.rmtree(fake_dir, ignore_errors=True)

        return fid

    def compute_metrics(self, x, x_hat, data_range=2.0):
        """Computes MSE, PSNR, and SSIM for evaluating image quality.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width).
        x_hat : torch.Tensor
            Generated images, same shape as `x`.
        data_range : float, optional
            Peak-to-peak range of the pixel values (default: 2.0 for [-1, 1] images).
            PSNR and the SSIM stabilising constants both use it.

        Returns
        -------
        tuple
            Tuple of (mse, psnr, ssim) as floats, where:
            - mse: Mean squared error.
            - psnr: Peak signal-to-noise ratio in dB (`inf` for identical images).
            - ssim: Structural similarity index (3x3 box window, mean over batch).

        Raises
        ------
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        with torch.no_grad():
            mse = F.mse_loss(x_hat, x)
            # PSNR = 10 * log10(MAX^2 / MSE) with MAX the data range, matching SSIM's constants.
            psnr = 10 * torch.log10(data_range ** 2 / mse)
            c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
            eps = 1e-8
            mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
            mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
            mu_xy = mu_x * mu_y
            sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
            sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
            sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
            ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
                (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
            )

        return mse.item(), psnr.item(), ssim.mean().item()

    def compute_lpips(self, x, x_hat):
        """Computes LPIPS using a pre-trained VGG network.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        float
            Mean LPIPS score over the batch.

        Raises
        ------
        RuntimeError
            If `lpips_model` is not initialized (constructed with `lpips=False`).
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if self.lpips_model is None:
            raise RuntimeError("LPIPS model not initialized; set lpips=True in __init__")
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        x = x.to(self.device)
        x_hat = x_hat.to(self.device)
        with torch.no_grad():
            return self.lpips_model(x, x_hat).mean().item()

    def forward(self, x, x_hat):
        """Computes specified metrics for ground truth and generated images.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        tuple
            A tuple containing:
            - fid: FID score (float, or `float('inf')` if `fid=False` or fails).
            - mse: Mean squared error (float, or None if `metrics=False`).
            - psnr: Peak signal-to-noise ratio (float, or None if `metrics=False`).
            - ssim: Structural similarity index (float, or None if `metrics=False`).
            - lpips: LPIPS score (float, or None if `lpips=False`).
        """
        fid = float('inf')
        mse, psnr, ssim = None, None, None
        lpips_score = None

        if self.metrics:
            mse, psnr, ssim = self.compute_metrics(x, x_hat)
        if self.fid:
            fid = self.compute_fid(x, x_hat)
        if self.lpips:
            lpips_score = self.compute_lpips(x, x_hat)

        return fid, mse, psnr, ssim, lpips_score