TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
sde/hyper_param.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class HyperParamsSDE(nn.Module):
    """Hyperparameters for SDE-based generative models.

    Manages the noise schedule and SDE-specific parameters for score-based generative
    models, including beta and sigma schedules, time steps, and variance computations,
    as described in Song et al. (2021). Supports trainable or fixed beta schedules and
    multiple scheduling methods for flexible noise control.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Starting value for beta schedule (default: 1e-4).
    beta_end : float, optional
        Ending value for beta schedule (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    sigma_start : float, optional
        Starting value for sigma schedule for VE method (default: 1e-3).
    sigma_end : float, optional
        Ending value for sigma schedule for VE method (default: 10.0).
    start : float, optional
        Start of the time interval for SDE integration (default: 0.0).
    end : float, optional
        End of the time interval for SDE integration (default: 1.0).

    Attributes
    ----------
    num_steps : int
        Number of diffusion steps.
    beta_start : float
        Minimum beta value.
    beta_end : float
        Maximum beta value.
    trainable_beta : bool
        Whether the beta schedule is trainable.
    beta_method : str
        Method used for beta schedule computation.
    sigma_start : float
        Minimum sigma value for VE method.
    sigma_end : float
        Maximum sigma value for VE method.
    start : float
        Start of the time interval.
    end : float
        End of the time interval.
    betas : torch.Tensor
        Beta schedule values, shape (num_steps,). Trainable if `trainable_beta` is True,
        otherwise a fixed buffer.
    cum_betas : torch.Tensor, optional
        Cumulative sum of betas scaled by `dt`, shape (num_steps,). Registered as a
        buffer only if `trainable_beta` is False; with trainable betas the cumulative
        sum is recomputed from the current `betas` whenever `get_variance` needs it.
    sigmas : torch.Tensor
        Sigma schedule for VE method, shape (num_steps,). Stored in the state dict
        only if `trainable_beta` is False.
    time : torch.Tensor
        Time points for SDE integration, shape (num_steps,). Registered as a
        non-persistent buffer so it follows the module across devices without being
        written to the state dict.
    dt : float
        Time step size for SDE integration, computed as (end - start) / num_steps.

    Raises
    ------
    ValueError
        If `beta_start` or `beta_end` do not satisfy 0 < beta_start < beta_end,
        `sigma_start` or `sigma_end` do not satisfy 0 < sigma_start < sigma_end,
        or `num_steps` is not positive.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear",
                 sigma_start=1e-3, sigma_end=10.0, start=0.0, end=1.0):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        # NOTE(review): `time` includes both endpoints, so its grid spacing is
        # (end - start) / (num_steps - 1) while `dt` divides by num_steps. Kept
        # as-is because callers may rely on either value.
        self.dt = (self.end - self.start) / self.num_steps
        time = torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32)
        # Non-persistent: follows .to()/.cuda() with the other buffers, but is
        # kept out of the state dict so existing checkpoints still load.
        self.register_buffer("time", time, persistent=False)

        # Geometric interpolation between sigma_start and sigma_end over `time`.
        sigmas = self.sigma_start * (self.sigma_end / self.sigma_start) ** time

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
            # The sigma schedule does not depend on beta, so the VE variance is
            # available in trainable mode too. Non-persistent keeps the state
            # dict keys identical to earlier releases ({"betas"} only).
            self.register_buffer("sigmas", sigmas, persistent=False)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
            self.register_buffer("sigmas", sigmas)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the SDE noise schedule using the chosen
        method, ensuring values are clamped within the specified range.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported beta schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            # Linear in sqrt(beta), then squared.
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            inv = 1.0 / torch.linspace(num_steps, 1, num_steps)
            if num_steps > 1:
                # Rescale 1/t from its own [min, max] onto [beta_min, beta_max].
                beta = beta_min + (beta_max - beta_min) * (inv - inv.min()) / (inv.max() - inv.min())
            else:
                # A single point has zero spread; min-max rescaling would be
                # 0/0 = NaN. Return the range start, as the linspace-based
                # schedules do for one step.
                beta = torch.full_like(inv, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        # Guard against floating-point drift just outside the requested range.
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def constrain_betas(self):
        """Constrains trainable betas to a valid range during training.

        Ensures that trainable beta values remain within the specified range
        [beta_start, beta_end] by clamping them in-place.

        Notes
        -----
        This method only applies when `trainable_beta` is True.
        """
        if self.trainable_beta:
            # In-place edit of a leaf parameter must not be tracked by autograd.
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)

    def _cumulative_betas(self):
        """Returns the cumulative beta integral, shape (num_steps,).

        Uses the precomputed `cum_betas` buffer when it exists (fixed schedule);
        otherwise recomputes it from the current `betas` so that trainable
        schedules always reflect their latest values and stay differentiable.
        """
        cached = getattr(self, "cum_betas", None)
        if cached is not None:
            return cached
        return torch.cumsum(self.betas, dim=0) * self.dt

    def get_variance(self, time_steps, method):
        """Computes the variance for the specified SDE method at given time steps.

        Calculates the variance used in SDE diffusion processes based on the method
        (VE, VP, or sub-VP), leveraging the sigma or cumulative beta schedules.
        Works for both fixed and trainable beta schedules.

        Parameters
        ----------
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, num_steps - 1].
        method : str
            SDE method to compute variance for. Supported methods: "ve", "vp", "sub-vp".

        Returns
        -------
        torch.Tensor
            Variance values for the specified time steps, shape (batch_size,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp").
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        elif method == "vp":
            return 1 - torch.exp(-self._cumulative_betas()[time_steps])
        elif method == "sub-vp":
            # NOTE(review): Song et al. (2021) give the sub-VP perturbation-kernel
            # variance as (1 - exp(-cum_beta)) ** 2; the expression below is kept
            # unchanged for compatibility -- confirm against callers.
            return 1 - torch.exp(-2 * self._cumulative_betas()[time_steps])
        else:
            raise ValueError(f"Unknown method: {method}")
|