TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ddim/__init__.py
ADDED
|
File without changes
|
ddim/forward_ddim.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ForwardDDIM(nn.Module):
    """Forward diffusion process of DDIM.

    Implements the forward diffusion process for Denoising Diffusion Implicit
    Models (DDIM), which perturbs input data by adding Gaussian noise over a
    series of time steps, as defined in Song et al. (2021, "Denoising
    Diffusion Implicit Models").

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing the noise schedule parameters.
        Expected to have attributes:
        - `num_steps`: Number of diffusion steps (int).
        - `trainable_beta`: Whether the noise schedule is trainable (bool).
        - `betas`: Noise schedule parameters (torch.Tensor, used if trainable_beta is True).
        - `sqrt_alpha_cumprod`: Precomputed sqrt of cumulative alpha product (torch.Tensor, used if trainable_beta is False).
        - `sqrt_one_minus_alpha_cumprod`: Precomputed sqrt of (1 - cumulative alpha product) (torch.Tensor, used if trainable_beta is False).
        - `compute_schedule`: Method computing the schedule from betas (callable, used if trainable_beta is True).

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Applies the forward diffusion process to the input data.

        Perturbs the input data `x0` by adding Gaussian noise according to the
        DDIM forward process at the specified time steps:
        xt = sqrt(alpha_cumprod_t) * x0 + sqrt(1 - alpha_cumprod_t) * noise.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor whose first dimension is the batch, e.g.
            (batch_size, channels, height, width). Any rank >= 1 is supported.
        noise : torch.Tensor
            Gaussian noise tensor of the same shape as `x0`.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each
            value is in the range [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Noisy data tensor `xt` at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If any value in `time_steps` is outside the valid range
            [0, hyper_params.num_steps - 1].
        """
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute the schedule from the live betas so gradients can flow
            # through a trainable noise schedule.
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
            sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[time_steps].to(x0.device)
            sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t[time_steps].to(x0.device)
        else:
            sqrt_alpha_cumprod_t = self.hyper_params.sqrt_alpha_cumprod[time_steps].to(x0.device)
            sqrt_one_minus_alpha_cumprod_t = self.hyper_params.sqrt_one_minus_alpha_cumprod[time_steps].to(x0.device)

        # Broadcast the per-sample scalars across all non-batch dimensions.
        # This generalizes the previous hard-coded view(-1, 1, 1, 1), which only
        # supported 4-D image tensors, to inputs of any rank (e.g. latents or
        # (B, C, T) sequences) while remaining identical for 4-D inputs.
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.view(broadcast_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.view(broadcast_shape)

        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
        return xt
|
ddim/hyper_param.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Hyperparameters for Denoising Diffusion Implicit Models (DDIM) noise schedule.
|
|
2
|
+
|
|
3
|
+
This module implements a flexible noise schedule for DDIM, as described in Song et al.
|
|
4
|
+
(2021, "Denoising Diffusion Implicit Models"). It supports multiple beta schedule methods,
|
|
5
|
+
trainable or fixed noise schedules, and a subsampled time step schedule for faster sampling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class HyperParamsDDIM(nn.Module):
    """Hyperparameters for DDIM noise schedule with flexible beta computation.

    Manages the noise schedule parameters for DDIM, including beta values,
    derived quantities (alphas, alpha_cumprod, etc.), and a subsampled time
    step schedule (tau schedule), as inspired by Song et al. (2021). Supports
    trainable or fixed schedules and various beta scheduling methods.

    Parameters
    ----------
    eta : float, optional
        Noise scaling factor for the DDIM reverse process (default: 0,
        deterministic sampling).
    num_steps : int, optional
        Total number of diffusion steps (default: 1000).
    tau_num_steps : int, optional
        Number of subsampled time steps for DDIM sampling (default: 100).
    beta_start : float, optional
        Starting value for beta (default: 1e-4).
    beta_end : float, optional
        Ending value for beta (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant",
        "inverse_time".

    Attributes
    ----------
    eta : float
        Noise scaling factor for the reverse process.
    num_steps : int
        Total number of diffusion steps.
    tau_num_steps : int
        Number of subsampled time steps.
    beta_start : float
        Minimum beta value.
    beta_end : float
        Maximum beta value.
    trainable_beta : bool
        Whether the beta schedule is trainable.
    beta_method : str
        Method used for beta schedule computation.
    betas : torch.Tensor
        Beta schedule values, shape (num_steps,). Trainable parameter if
        `trainable_beta` is True, otherwise a fixed buffer.
    alphas : torch.Tensor, optional
        Alpha values (1 - betas), shape (num_steps,). Only registered when
        `trainable_beta` is False.
    alpha_cumprod : torch.Tensor, optional
        Cumulative product of alphas, shape (num_steps,). Only registered when
        `trainable_beta` is False.
    sqrt_alpha_cumprod : torch.Tensor, optional
        Square root of alpha_cumprod, shape (num_steps,). Only registered when
        `trainable_beta` is False.
    sqrt_one_minus_alpha_cumprod : torch.Tensor, optional
        Square root of (1 - alpha_cumprod), shape (num_steps,). Only registered
        when `trainable_beta` is False.
    tau_indices : torch.Tensor
        Indices for subsampled time steps, shape (tau_num_steps,).

    Raises
    ------
    ValueError
        If `beta_start` or `beta_end` do not satisfy 0 < beta_start < beta_end < 1,
        if `num_steps` is not positive, or if `tau_num_steps` is not positive.
    """

    def __init__(self, eta=None, num_steps=1000, tau_num_steps=100, beta_start=1e-4, beta_end=0.02,
                 trainable_beta=False, beta_method="linear"):
        super().__init__()
        # Validate arguments before doing any work or storing state.
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        if tau_num_steps <= 0:
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must be positive")

        # Explicit None check: `eta or 0` would also coerce an explicitly
        # passed falsy value, which is not the intended contract.
        self.eta = 0 if eta is None else eta
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)

        if trainable_beta:
            # Derived quantities are recomputed on the fly (see compute_schedule)
            # so gradients can flow through the trainable betas.
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced subsampled step indices over [0, num_steps - 1].
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the DDIM noise schedule using
        the chosen method, ensuring values are clamped within the specified
        range.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported beta schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            # Rescale to [beta_min, beta_max]. Guard against a degenerate span:
            # with num_steps == 1 the raw schedule is constant and the previous
            # unconditional division produced 0/0 = NaN.
            span = beta.max() - beta.min()
            if span > 0:
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                beta = torch.full_like(beta, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        # Defensive clamp: keeps every schedule strictly inside the valid range.
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self):
        """Computes the subsampled (tau) noise schedule for DDIM.

        Returns the noise schedule parameters for the subsampled time steps
        used in DDIM sampling, based on the `tau_indices`.

        Returns
        -------
        tuple
            A tuple containing:
            - tau_betas: Beta values for subsampled steps, shape (tau_num_steps,).
            - tau_alphas: Alpha values for subsampled steps, shape (tau_num_steps,).
            - tau_alpha_cumprod: Cumulative product of alphas for subsampled steps, shape (tau_num_steps,).
            - tau_sqrt_alpha_cumprod: Square root of alpha_cumprod for subsampled steps, shape (tau_num_steps,).
            - tau_sqrt_one_minus_alpha_cumprod: Square root of (1 - alpha_cumprod) for subsampled steps, shape (tau_num_steps,).
        """
        if self.trainable_beta:
            # Derived quantities are not cached for trainable schedules.
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule(self.betas)
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, betas):
        """Computes noise schedule parameters dynamically from betas.

        Calculates the derived noise schedule parameters (alphas,
        alpha_cumprod, etc.) from the provided beta values, as used in the
        DDIM forward and reverse processes.

        Parameters
        ----------
        betas : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Returns
        -------
        tuple
            A tuple containing:
            - betas: Input beta values, shape (num_steps,).
            - alphas: 1 - betas, shape (num_steps,).
            - alpha_cumprod: Cumulative product of alphas, shape (num_steps,).
            - sqrt_alpha_cumprod: Square root of alpha_cumprod, shape (num_steps,).
            - sqrt_one_minus_alpha_cumprod: Square root of (1 - alpha_cumprod), shape (num_steps,).
        """
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_cumprod, torch.sqrt(alpha_cumprod), torch.sqrt(1 - alpha_cumprod)

    def constrain_betas(self):
        """Constrains trainable betas to a valid range during training.

        Ensures that trainable beta values remain within the specified range
        [beta_start, beta_end] by clamping them in-place.

        Notes
        -----
        This method only applies when `trainable_beta` is True; it is a no-op
        otherwise.
        """
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
|