TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddim/__init__.py ADDED
File without changes
ddim/forward_ddim.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class ForwardDDIM(nn.Module):
    """Forward diffusion process of DDIM.

    Implements the forward (noising) process for Denoising Diffusion Implicit
    Models (DDIM), which perturbs input data by adding Gaussian noise over a
    series of time steps, as defined in Song et al. (2021, "Denoising
    Diffusion Implicit Models").

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing the noise schedule. Expected attributes:
        - ``num_steps``: number of diffusion steps (int).
        - ``trainable_beta``: whether the noise schedule is trainable (bool).
        - ``betas``: noise schedule parameters (torch.Tensor; used when
          ``trainable_beta`` is True).
        - ``sqrt_alpha_cumprod`` / ``sqrt_one_minus_alpha_cumprod``: precomputed
          cumulative schedule tensors (used when ``trainable_beta`` is False).
        - ``compute_schedule``: callable recomputing the schedule from betas
          (used when ``trainable_beta`` is True).

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Apply the forward diffusion process to the input data.

        Computes ``xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``
        at the requested per-sample time steps.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor of shape (batch_size, ...). Any rank >= 1 is
            supported; schedule coefficients are broadcast over all non-batch
            dimensions (generalizes the original image-only 4-D layout).
        noise : torch.Tensor
            Gaussian noise tensor of the same shape as `x0`.
        time_steps : torch.Tensor
            Long tensor of time step indices, shape (batch_size,), each value
            in [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Noisy data tensor `xt` at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If `noise` does not have the same shape as `x0`, or if any value in
            `time_steps` is outside [0, hyper_params.num_steps - 1].
        """
        # A shape mismatch would silently broadcast and corrupt the noising
        # target, so fail loudly instead.
        if noise.shape != x0.shape:
            raise ValueError(
                f"noise shape {tuple(noise.shape)} must match x0 shape {tuple(x0.shape)}"
            )
        if not torch.all((time_steps >= 0) & (time_steps < self.hyper_params.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {self.hyper_params.num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute the schedule from the current (trainable) betas so
            # gradients can flow through the noise schedule.
            _, _, _, sqrt_alpha_cumprod_t, sqrt_one_minus_alpha_cumprod_t = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
            sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t[time_steps].to(x0.device)
            sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t[time_steps].to(x0.device)
        else:
            sqrt_alpha_cumprod_t = self.hyper_params.sqrt_alpha_cumprod[time_steps].to(x0.device)
            sqrt_one_minus_alpha_cumprod_t = self.hyper_params.sqrt_one_minus_alpha_cumprod[time_steps].to(x0.device)

        # Reshape per-sample coefficients to (batch, 1, 1, ...) so they
        # broadcast over every non-batch dimension of x0, whatever its rank.
        coef_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_cumprod_t = sqrt_alpha_cumprod_t.view(coef_shape)
        sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alpha_cumprod_t.view(coef_shape)

        xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
        return xt
ddim/hyper_param.py ADDED
@@ -0,0 +1,225 @@
1
+ """Hyperparameters for Denoising Diffusion Implicit Models (DDIM) noise schedule.
2
+
3
+ This module implements a flexible noise schedule for DDIM, as described in Song et al.
4
+ (2021, "Denoising Diffusion Implicit Models"). It supports multiple beta schedule methods,
5
+ trainable or fixed noise schedules, and a subsampled time step schedule for faster sampling.
6
+ """
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+
14
class HyperParamsDDIM(nn.Module):
    """Hyperparameters for DDIM noise schedule with flexible beta computation.

    Manages the noise schedule parameters for DDIM, including beta values,
    derived quantities (alphas, alpha_cumprod, etc.), and a subsampled time
    step schedule (tau schedule), as inspired by Song et al. (2021,
    "Denoising Diffusion Implicit Models"). Supports trainable or fixed
    schedules and various beta scheduling methods.

    Parameters
    ----------
    eta : float, optional
        Noise scaling factor for the DDIM reverse process (default: 0,
        deterministic). An explicit value of ``0.0`` is kept as given.
    num_steps : int, optional
        Total number of diffusion steps (default: 1000).
    tau_num_steps : int, optional
        Number of subsampled time steps for DDIM sampling (default: 100).
        Must satisfy 0 < tau_num_steps <= num_steps.
    beta_start : float, optional
        Starting value for beta (default: 1e-4).
    beta_end : float, optional
        Ending value for beta (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Attributes
    ----------
    betas : torch.Tensor
        Beta schedule values, shape (num_steps,). An ``nn.Parameter`` if
        `trainable_beta` is True, otherwise a fixed buffer.
    alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod : torch.Tensor
        Derived schedule buffers, shape (num_steps,). Registered only when
        `trainable_beta` is False (otherwise they must be recomputed from the
        live betas via `compute_schedule`).
    tau_indices : torch.Tensor
        Indices for subsampled time steps, shape (tau_num_steps,).

    Raises
    ------
    ValueError
        If `beta_start`/`beta_end` do not satisfy 0 < beta_start < beta_end < 1,
        if `num_steps` is not positive, or if `tau_num_steps` is not in
        (0, num_steps].
    """

    def __init__(self, eta=None, num_steps=1000, tau_num_steps=100, beta_start=1e-4, beta_end=0.02,
                 trainable_beta=False, beta_method="linear"):
        super().__init__()
        # Validate before storing any state so a failed construction leaves
        # no half-initialized module behind.
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")
        # tau_num_steps > num_steps would produce duplicate tau indices and
        # tau_num_steps <= 0 an empty schedule — both silently break sampling.
        if not (0 < tau_num_steps <= num_steps):
            raise ValueError(f"tau_num_steps ({tau_num_steps}) must satisfy 0 < tau_num_steps <= num_steps ({num_steps})")

        # Explicit None check: `eta or 0` would clobber a legitimately falsy
        # eta (e.g. an explicit 0.0) via truthiness.
        self.eta = 0 if eta is None else eta
        self.num_steps = num_steps
        self.tau_num_steps = tau_num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        beta_range = (beta_start, beta_end)
        betas_init = self.compute_beta_schedule(beta_range, num_steps, beta_method)

        if trainable_beta:
            # Trainable schedule: derived quantities must be recomputed from
            # the live betas each step, so only the parameter is stored.
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_cumprod', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_cumprod', torch.sqrt(self.alpha_cumprod))
            self.register_buffer('sqrt_one_minus_alpha_cumprod', torch.sqrt(1 - self.alpha_cumprod))

        # Evenly spaced subsampled time steps covering [0, num_steps - 1].
        self.register_buffer('tau_indices', torch.linspace(0, num_steps - 1, tau_num_steps, dtype=torch.long))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Compute the beta schedule based on the specified method.

        Generates a sequence of beta values for the DDIM noise schedule using
        the chosen method, clamped to the specified range.

        Parameters
        ----------
        beta_range : tuple
            (min_beta, max_beta) valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            # Rescale the raw 1/t curve onto [beta_min, beta_max].
            beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / (beta.max() - beta.min())
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(
                f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def get_tau_schedule(self):
        """Compute the subsampled (tau) noise schedule for DDIM.

        Returns the noise schedule parameters at the subsampled time steps
        given by `tau_indices`. For trainable schedules the full schedule is
        recomputed from the live betas first.

        Returns
        -------
        tuple
            (tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod,
            tau_sqrt_one_minus_alpha_cumprod), each of shape (tau_num_steps,).
        """
        if self.trainable_beta:
            betas, alphas, alpha_cumprod, sqrt_alpha_cumprod, sqrt_one_minus_alpha_cumprod = self.compute_schedule(self.betas)
        else:
            betas = self.betas
            alphas = self.alphas
            alpha_cumprod = self.alpha_cumprod
            sqrt_alpha_cumprod = self.sqrt_alpha_cumprod
            sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alpha_cumprod

        tau_betas = betas[self.tau_indices]
        tau_alphas = alphas[self.tau_indices]
        tau_alpha_cumprod = alpha_cumprod[self.tau_indices]
        tau_sqrt_alpha_cumprod = sqrt_alpha_cumprod[self.tau_indices]
        tau_sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod[self.tau_indices]

        return tau_betas, tau_alphas, tau_alpha_cumprod, tau_sqrt_alpha_cumprod, tau_sqrt_one_minus_alpha_cumprod

    def compute_schedule(self, betas):
        """Compute noise schedule parameters dynamically from betas.

        Derives the schedule quantities (alphas, alpha_cumprod, etc.) from the
        provided betas, as used in the DDIM forward and reverse processes.

        Parameters
        ----------
        betas : torch.Tensor
            Beta values, shape (num_steps,).

        Returns
        -------
        tuple
            (betas, alphas, alpha_cumprod, sqrt_alpha_cumprod,
            sqrt_one_minus_alpha_cumprod), each of shape (num_steps,).
        """
        alphas = 1 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_cumprod, torch.sqrt(alpha_cumprod), torch.sqrt(1 - alpha_cumprod)

    def constrain_betas(self):
        """Constrain trainable betas to a valid range during training.

        Clamps trainable beta values in-place to [beta_start, beta_end].

        Notes
        -----
        No-op when `trainable_beta` is False.
        """
        if self.trainable_beta:
            # no_grad: clamping is a projection step, not part of the
            # differentiable computation graph.
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)