TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/forward_idm.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ForwardSDE(nn.Module):
    """Single discretised forward step of a score-based SDE diffusion process.

    Parameters
    ----------
    hyper_params : nn.Module
        Schedule container. Depending on ``method`` it must expose ``dt``,
        ``sigmas`` ("ve"), ``betas`` ("vp", "sub-vp", "ode") and ``cum_betas``
        ("sub-vp"), each indexable by ``time_steps``.
    method : str
        One of ``"ve"``, ``"vp"``, ``"sub-vp"`` or ``"ode"``.
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    def forward(self, x0, noise, time_steps):
        """Apply one forward step of the configured process.

        Parameters
        ----------
        x0 : torch.Tensor
            Current state, shape ``(batch, ...)``.
        noise : torch.Tensor
            Noise with the same shape as ``x0``.
        time_steps : torch.Tensor
            Integer step index per sample, shape ``(batch,)``.

        Returns
        -------
        torch.Tensor
            Updated state with the same shape as ``x0``.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported values.
        """
        dt = self.hyper_params.dt
        # Per-sample schedule values broadcast over every non-batch dimension.
        shape = (-1,) + (1,) * (x0.dim() - 1)

        if self.method == "ve":
            sigmas = self.hyper_params.sigmas
            sigma_t = sigmas[time_steps]
            # Decide per sample: a sample at t=0 has no predecessor (sigma_prev = 0),
            # while samples at t>0 use sigmas[t-1]. Clamping keeps the index valid
            # (t-1 = -1 would otherwise wrap to the last sigma).
            prev_idx = (time_steps - 1).clamp(min=0)
            has_prev = (time_steps > 0).to(sigma_t.device)
            sigma_t_prev = torch.where(has_prev, sigmas[prev_idx], torch.zeros_like(sigma_t))
            sigma_diff = torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_t_prev ** 2, min=0))
            return x0 + noise * sigma_diff.view(shape)

        if self.method not in ("vp", "sub-vp", "ode"):
            raise ValueError(f"Unknown method: {self.method}")

        betas = self.hyper_params.betas[time_steps].view(shape)
        drift = -0.5 * betas * x0 * dt

        if self.method == "ode":
            # Deterministic variant: drift only, no noise term.
            return x0 + drift

        if self.method == "vp":
            diffusion = torch.sqrt(betas * dt) * noise
        else:  # "sub-vp"
            cum_betas = self.hyper_params.cum_betas[time_steps].view(shape)
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
        return x0 + drift + diffusion
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ForwardDDPM(nn.Module):
    """Forward (noising) diffusion process of DDPM.

    Produces ``x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

    Parameters
    ----------
    hyper_params : nn.Module
        Schedule container exposing ``num_steps`` and ``trainable_beta``. When
        ``trainable_beta`` is True it must provide ``betas`` and
        ``compute_schedule(betas)``; otherwise ``sqrt_alpha_bars`` and
        ``sqrt_one_minus_alpha_bars``.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Noise ``x0`` to the given time steps.

        Parameters
        ----------
        x0 : torch.Tensor
            Clean inputs, shape ``(batch, ...)`` (any rank >= 1).
        noise : torch.Tensor
            Noise with the same shape as ``x0``.
        time_steps : torch.Tensor
            Integer step index per sample, shape ``(batch,)``.

        Returns
        -------
        torch.Tensor
            Noised samples with the same shape as ``x0``.

        Raises
        ------
        ValueError
            If any time step lies outside ``[0, num_steps - 1]``.
        """
        hp = self.hyper_params
        if not torch.all((time_steps >= 0) & (time_steps < hp.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {hp.num_steps - 1}")

        if hp.trainable_beta:
            # Recompute from the current learnable betas rather than using cached buffers.
            _, _, _, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = hp.compute_schedule(hp.betas)
        else:
            sqrt_alpha_bars = hp.sqrt_alpha_bars
            sqrt_one_minus_alpha_bars = hp.sqrt_one_minus_alpha_bars

        # Broadcast the per-sample coefficients over all non-batch dimensions,
        # whatever the rank of x0 (4-D images, latents, flat vectors, ...).
        shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = sqrt_alpha_bars[time_steps].to(x0.device).view(shape)
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bars[time_steps].to(x0.device).view(shape)

        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class ForwardDDIM(nn.Module):
    """Forward (noising) diffusion process of DDIM.

    Produces ``x_t = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise``.

    Parameters
    ----------
    hyper_params : nn.Module
        Schedule container exposing ``num_steps`` and ``trainable_beta``. When
        ``trainable_beta`` is True it must provide ``betas`` and
        ``compute_schedule(betas)``; otherwise ``sqrt_alpha_cumprod`` and
        ``sqrt_one_minus_alpha_cumprod``.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Noise ``x0`` to the given time steps.

        Parameters
        ----------
        x0 : torch.Tensor
            Clean inputs, shape ``(batch, ...)`` (any rank >= 1).
        noise : torch.Tensor
            Noise with the same shape as ``x0``.
        time_steps : torch.Tensor
            Integer step index per sample, shape ``(batch,)``.

        Returns
        -------
        torch.Tensor
            Noised samples with the same shape as ``x0``.

        Raises
        ------
        ValueError
            If any time step lies outside ``[0, num_steps - 1]``.
        """
        hp = self.hyper_params
        if not torch.all((time_steps >= 0) & (time_steps < hp.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {hp.num_steps - 1}")

        if hp.trainable_beta:
            # Recompute from the current learnable betas rather than using cached buffers.
            _, _, _, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = hp.compute_schedule(hp.betas)
        else:
            sqrt_alpha_cumprod = hp.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = hp.sqrt_one_minus_alpha_cumprod

        # Broadcast the per-sample coefficients over all non-batch dimensions,
        # whatever the rank of x0 (4-D images, latents, flat vectors, ...).
        shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod[time_steps].to(x0.device).view(shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod[time_steps].to(x0.device).view(shape)

        return sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
|
ldm/hyper_param.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class HyperParamsDDPM(nn.Module):
    """Hyperparameters for the DDPM noise schedule with flexible beta computation.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps; must be positive (default: 1000).
    beta_start, beta_end : float, optional
        Beta range; must satisfy ``0 < beta_start < beta_end < 1``
        (defaults: 1e-4, 0.02).
    trainable_beta : bool, optional
        If True, ``betas`` is an ``nn.Parameter`` and derived quantities must be
        obtained via ``compute_schedule``. Otherwise ``betas`` and the derived
        schedule tensors are registered as fixed buffers (default: False).
    beta_method : str, optional
        One of "linear", "sigmoid", "quadratic", "constant", "inverse_time"
        (default: "linear").
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear"):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        # validate inputs
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)  # trainable parameter
        else:
            # Fixed schedule: precompute everything the forward process needs.
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_bars', torch.sqrt(self.alpha_bars))
            self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1 - self.alpha_bars))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute the beta schedule for the selected method.

        Args:
            beta_range: Tuple of (min_beta, max_beta) values.
            num_steps: Number of diffusion steps.
            method: "linear", "sigmoid", "quadratic", "constant" or "inverse_time".

        Returns:
            Tensor of beta values, shape [num_steps], clamped to beta_range.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            if num_steps > 1:
                # scale to beta_range
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / (beta.max() - beta.min())
            else:
                # A single step has max == min; rescaling would divide by zero
                # (NaN). Use beta_min, matching what "linear" yields for one step.
                beta = torch.full((1,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        return torch.clamp(beta, min=beta_min, max=beta_max)

    @staticmethod
    def compute_schedule(betas):
        """Compute noise schedule parameters dynamically from ``betas``.

        Returns:
            Tuple ``(betas, alphas, alpha_bars, sqrt(alpha_bars), sqrt(1 - alpha_bars))``.
        """
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op otherwise."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class HyperParamsDDIM(nn.Module):
    """Hyperparameters for the DDIM noise schedule with flexible beta computation.

    Parameters
    ----------
    eta : float or None, optional
        DDIM ``eta`` value; ``None`` (or any falsy value) is stored as 0.
    num_steps : int, optional
        Number of diffusion steps; must be positive (default: 1000).
    tau_num_steps : int, optional
        Number of sub-sampled steps; must satisfy
        ``0 < tau_num_steps <= num_steps`` (default: 100).
    beta_start, beta_end : float, optional
        Beta range; must satisfy ``0 < beta_start < beta_end < 1``.
    trainable_beta : bool, optional
        If True, ``betas`` is an ``nn.Parameter`` and the schedule is recomputed
        on demand. Otherwise it is registered as fixed buffers.
    beta_method : str, optional
        One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    """

    def __init__(self, eta=None, num_steps=1000, tau_num_steps=100, beta_start=1e-4, beta_end=0.02,
                 trainable_beta=False, beta_method="linear"):
        super().__init__()
        self.eta = eta or 0
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        # More sub-steps than steps would produce duplicate tau indices; zero or
        # negative would produce an empty schedule.
        if not (0 < tau_num_steps <= num_steps):
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must satisfy 0 < tau_num_steps <= num_steps ({num_steps})")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced indices into the full schedule used as the sub-sampled steps.
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute the beta schedule for the selected method.

        Returns a tensor of shape [num_steps] clamped to ``beta_range``.
        Raises ValueError for an unsupported ``method``.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            if num_steps > 1:
                # scale to beta_range
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / (beta.max() - beta.min())
            else:
                # A single step has max == min; rescaling would divide by zero (NaN).
                beta = torch.full((1,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        return torch.clamp(beta, min=beta_min, max=beta_max)

    def get_tau_schedule(self):
        """Return the schedule restricted to ``tau_indices``.

        Returns:
            Tuple ``(betas, alphas, alpha_cumprod, sqrt_alpha_cumprod,
            sqrt_one_minus_alpha_cumprod)``, each indexed by ``tau_indices``.
        """
        if self.trainable_beta:
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule(self.betas)
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        idx = self.tau_indices
        return (betas[idx], alphas[idx], alpha_cumprod[idx],
                sqrt_alpha_cumprod[idx], sqrt_one_minus_alpha_cumprod[idx])

    def compute_schedule(self, betas):
        """Compute noise schedule parameters dynamically from ``betas``."""
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_cumprod, torch.sqrt(alpha_cumprod), torch.sqrt(1 - alpha_cumprod)

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op otherwise."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class HyperParamsSDE(nn.Module):
    """Hyperparameters for score-based SDE diffusion (VE, VP and sub-VP).

    Parameters
    ----------
    num_steps : int, optional
        Number of discretisation steps; must be positive (default: 1000).
    beta_start, beta_end : float, optional
        Beta range for the VP / sub-VP processes; must satisfy
        ``0 < beta_start < beta_end`` (defaults: 1e-4, 0.02).
    trainable_beta : bool, optional
        If True, ``betas`` is an ``nn.Parameter``; otherwise a buffer.
    beta_method : str, optional
        One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    sigma_start, sigma_end : float, optional
        Endpoints of the VE noise scales ``sigma_start * (sigma_end / sigma_start) ** t``;
        must satisfy ``0 < sigma_start < sigma_end`` (defaults: 1e-3, 10.0).
    start, end : float, optional
        Time interval discretised into ``num_steps`` points (defaults: 0.0, 1.0).
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear",
                 sigma_start=1e-3, sigma_end=10.0, start=0.0, end=1.0):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)
        self.time = torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32)
        self.dt = (self.end - self.start) / self.num_steps

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
        # Registered in both modes so every method ("ve", "vp", "sub-vp") works.
        # NOTE: cum_betas (discrete integral of beta) is computed once from the
        # initial schedule; with trainable_beta it does not follow later updates.
        self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
        self.register_buffer("sigmas", self.sigma_start * (self.sigma_end / self.sigma_start) ** self.time)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute the beta schedule for the selected method.

        Returns a tensor of shape [num_steps] clamped to ``beta_range``.
        Raises ValueError for an unsupported ``method``.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            if num_steps > 1:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / (beta.max() - beta.min())
            else:
                # A single step has max == min; rescaling would divide by zero (NaN).
                beta = torch.full((1,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        return torch.clamp(beta, min=beta_min, max=beta_max)

    def constrain_betas(self):
        """Clamp trainable betas in place to ``[beta_start, beta_end]``; no-op otherwise."""
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)

    def get_variance(self, time_steps, method):
        """Marginal variance of the perturbation kernel at ``time_steps``.

        With ``B = cum_betas[t]`` (the accumulated beta integral):
        - "ve":     sigma_t ** 2
        - "vp":     1 - exp(-B)
        - "sub-vp": (1 - exp(-B)) ** 2  (the sub-VP marginal std is 1 - exp(-B))

        Raises ValueError for an unknown ``method``.
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        elif method == "vp":
            return 1 - torch.exp(-self.cum_betas[time_steps])
        elif method == "sub-vp":
            return (1 - torch.exp(-self.cum_betas[time_steps])) ** 2
        else:
            raise ValueError(f"Unknown method: {method}")
|
ldm/metrics.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
import lpips
|
|
4
|
+
from pytorch_fid import fid_score
|
|
5
|
+
import shutil
|
|
6
|
+
from torchvision.utils import save_image
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Metrics:
    """Computes image quality metrics for evaluating diffusion models.

    Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
    Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
    Image Patch Similarity (LPIPS) for comparing generated and ground truth images.
    Instances are callable; ``metrics(x, x_hat)`` is equivalent to ``metrics.forward(x, x_hat)``.

    Parameters
    ----------
    device : str, optional
        Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
    fid : bool, optional
        If True, compute FID score (default: True).
    metrics : bool, optional
        If True, compute MSE, PSNR, and SSIM (default: False).
    lpips : bool, optional
        If True, compute LPIPS using VGG backbone (default: False).

    Attributes
    ----------
    device : str
        Computation device.
    fid : bool
        Flag for FID computation.
    metrics : bool
        Flag for MSE, PSNR, SSIM computation.
    lpips : bool
        Flag for LPIPS computation.
    lpips_model : lpips.LPIPS or None
        LPIPS model (VGG backbone) if `lpips=True`; otherwise, None.
    temp_dir_real : str
        Name prefix for the per-call temporary directory holding real images
        during FID computation.
    temp_dir_fake : str
        Name prefix for the per-call temporary directory holding fake
        (generated) images during FID computation.
    """

    def __init__(self, device="cuda", fid=True, metrics=False, lpips=False):
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips
        if lpips:
            # The boolean ``lpips`` argument shadows the ``lpips`` package inside
            # this method, so import the package under an alias.
            import lpips as lpips_lib
            self.lpips_model = lpips_lib.LPIPS(net='vgg').to(device)
        else:
            self.lpips_model = None
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"

    def compute_fid(self, real_images, fake_images):
        """Computes the Fréchet Inception Distance (FID) between real and generated images.

        Saves images to freshly created, uniquely named temporary directories and
        uses Inception V3 to compute FID, removing those directories afterward.

        Parameters
        ----------
        real_images : torch.Tensor
            Real images, shape (batch_size, channels, height, width), in [-1, 1].
        fake_images : torch.Tensor
            Generated images, same shape, in [-1, 1].

        Returns
        -------
        float
            FID score, or `float('inf')` if computation fails.

        Raises
        ------
        ValueError
            If the two tensors have different shapes.

        Notes
        -----
        - Images are normalized to [0, 1] and saved as PNG files for FID computation.
        - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
        """
        import tempfile

        if real_images.shape != fake_images.shape:
            raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")

        real_images = ((real_images + 1) / 2).clamp(0, 1).cpu()
        fake_images = ((fake_images + 1) / 2).clamp(0, 1).cpu()

        # Unique directories per call: never overwrite or delete a pre-existing
        # user folder, and never mix in stale images from an earlier run.
        real_dir = tempfile.mkdtemp(prefix=f"{self.temp_dir_real}_")
        fake_dir = tempfile.mkdtemp(prefix=f"{self.temp_dir_fake}_")

        try:
            for i, (real, fake) in enumerate(zip(real_images, fake_images)):
                save_image(real, os.path.join(real_dir, f"{i}.png"))
                save_image(fake, os.path.join(fake_dir, f"{i}.png"))

            fid = fid_score.calculate_fid_given_paths(
                paths=[real_dir, fake_dir],
                batch_size=50,
                device=self.device,
                dims=2048
            )
        except Exception as e:
            # Best-effort metric: report and fall back to inf rather than abort evaluation.
            print(f"Error computing FID: {e}")
            fid = float('inf')
        finally:
            shutil.rmtree(real_dir, ignore_errors=True)
            shutil.rmtree(fake_dir, ignore_errors=True)

        return fid

    def compute_metrics(self, x, x_hat, data_range=2.0):
        """Computes MSE, PSNR, and SSIM for evaluating image quality.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width).
        x_hat : torch.Tensor
            Generated images, same shape as `x`.
        data_range : float, optional
            Dynamic range of the pixel values (max - min). The default 2.0
            corresponds to images in [-1, 1]. It is used for both the PSNR peak
            and the SSIM stabilising constants.

        Returns
        -------
        tuple
            Tuple of (mse, psnr, ssim) as floats, where:
            - mse: Mean squared error.
            - psnr: Peak signal-to-noise ratio in dB (inf when mse == 0).
            - ssim: Structural similarity index (mean over batch, 3x3 windows).

        Raises
        ------
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        with torch.no_grad():
            mse = F.mse_loss(x_hat, x)
            # PSNR = 10 * log10(MAX^2 / MSE), with MAX equal to the data range.
            psnr = 10 * torch.log10(data_range ** 2 / mse)

            c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
            eps = 1e-8

            # Exclude zero padding from the window averages so border statistics
            # are not biased towards zero.
            def window_mean(t):
                return F.avg_pool2d(t, kernel_size=3, stride=1, padding=1, count_include_pad=False)

            mu_x = window_mean(x)
            mu_y = window_mean(x_hat)
            mu_xy = mu_x * mu_y
            sigma_x_sq = window_mean(x.pow(2)) - mu_x.pow(2)
            sigma_y_sq = window_mean(x_hat.pow(2)) - mu_y.pow(2)
            sigma_xy = window_mean(x * x_hat) - mu_xy
            ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
                (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
            )

        return mse.item(), psnr.item(), ssim.mean().item()

    def compute_lpips(self, x, x_hat):
        """Computes LPIPS using a pre-trained VGG network.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        float
            Mean LPIPS score over the batch.

        Raises
        ------
        RuntimeError
            If `lpips_model` is not initialized (constructed with `lpips=False`).
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if self.lpips_model is None:
            raise RuntimeError("LPIPS model not initialized; set lpips=True in __init__")
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")

        x = x.to(self.device)
        x_hat = x_hat.to(self.device)
        # Evaluation only: skip autograd graph construction.
        with torch.no_grad():
            return self.lpips_model(x, x_hat).mean().item()

    def forward(self, x, x_hat):
        """Computes specified metrics for ground truth and generated images.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        tuple
            A tuple containing:
            - fid: FID score (float, or `float('inf')` if `fid=False` or fails).
            - mse: Mean squared error (float, or None if `metrics=False`).
            - psnr: Peak signal-to-noise ratio (float, or None if `metrics=False`).
            - ssim: Structural similarity index (float, or None if `metrics=False`).
            - lpips: LPIPS score (float, or None if `lpips=False`).
        """
        fid = float('inf')
        mse, psnr, ssim = None, None, None
        lpips_score = None

        if self.metrics:
            mse, psnr, ssim = self.compute_metrics(x, x_hat)
        if self.fid:
            fid = self.compute_fid(x, x_hat)
        if self.lpips:
            lpips_score = self.compute_lpips(x, x_hat)

        return fid, mse, psnr, ssim, lpips_score

    def __call__(self, x, x_hat):
        """Alias for :meth:`forward` so instances can be called like modules."""
        return self.forward(x, x_hat)
|