TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
sde/hyper_param.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
+
7
class HyperParamsSDE(nn.Module):
    """Hyperparameters for SDE-based generative models.

    Manages the noise schedule and SDE-specific parameters for score-based generative
    models, including beta and sigma schedules, time steps, and variance computations,
    as described in Song et al. (2021). Supports trainable or fixed beta schedules and
    multiple scheduling methods for flexible noise control.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Starting value for beta schedule (default: 1e-4).
    beta_end : float, optional
        Ending value for beta schedule (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".
    sigma_start : float, optional
        Starting value for sigma schedule for VE method (default: 1e-3).
    sigma_end : float, optional
        Ending value for sigma schedule for VE method (default: 10.0).
    start : float, optional
        Start of the time interval for SDE integration (default: 0.0).
    end : float, optional
        End of the time interval for SDE integration (default: 1.0).

    Attributes
    ----------
    num_steps : int
        Number of diffusion steps.
    beta_start : float
        Minimum beta value.
    beta_end : float
        Maximum beta value.
    trainable_beta : bool
        Whether the beta schedule is trainable.
    beta_method : str
        Method used for beta schedule computation.
    sigma_start : float
        Minimum sigma value for VE method.
    sigma_end : float
        Maximum sigma value for VE method.
    start : float
        Start of the time interval.
    end : float
        End of the time interval.
    betas : torch.Tensor
        Beta schedule values, shape (num_steps,). An ``nn.Parameter`` if
        `trainable_beta` is True, otherwise a fixed buffer.
    cum_betas : torch.Tensor, optional
        Cumulative sum of betas scaled by `dt`, shape (num_steps,). Registered as a
        buffer only when `trainable_beta` is False; with trainable betas the same
        quantity is recomputed from the current parameter values inside
        `get_variance`, so it tracks training updates.
    sigmas : torch.Tensor
        Sigma schedule for the VE method, shape (num_steps,), computed as
        ``sigma_start * (sigma_end / sigma_start) ** time``. Always available. It is
        stored in ``state_dict`` only when `trainable_beta` is False, which keeps
        checkpoint layouts unchanged.
    time : torch.Tensor
        Time points ``torch.linspace(start, end, num_steps)``, shape (num_steps,).
        Registered as a non-persistent buffer: it follows ``.to(device)`` but is not
        saved in ``state_dict``.
    dt : float
        Time step size, computed as (end - start) / num_steps. Note that this differs
        from the spacing of `time`, which is (end - start) / (num_steps - 1).

    Raises
    ------
    ValueError
        If `beta_start` or `beta_end` do not satisfy 0 < beta_start < beta_end,
        `sigma_start` or `sigma_end` do not satisfy 0 < sigma_start < sigma_end,
        or `num_steps` is not positive.
    """
    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear",
                 sigma_start=1e-3, sigma_end=10.0, start=0.0, end=1.0):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method
        self.sigma_start = sigma_start
        self.sigma_end = sigma_end
        self.start = start
        self.end = end

        if not (0 < self.beta_start < self.beta_end):
            raise ValueError(f"beta_start ({self.beta_start}) and beta_end ({self.beta_end}) must satisfy 0 < start < end")
        if not (0 < self.sigma_start < self.sigma_end):
            raise ValueError(f"sigma_start ({self.sigma_start}) and sigma_end ({self.sigma_end}) must satisfy 0 < start < end")
        if self.num_steps <= 0:
            raise ValueError(f"num_steps ({self.num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        time = torch.linspace(self.start, self.end, self.num_steps, dtype=torch.float32)
        # Non-persistent buffer: moves with the module across devices while keeping
        # the state_dict keys identical to a plain-attribute implementation.
        self.register_buffer("time", time, persistent=False)
        self.dt = (self.end - self.start) / self.num_steps

        # The VE sigma schedule is independent of beta, so it is needed in both modes.
        sigmas = self.sigma_start * (self.sigma_end / self.sigma_start) ** time

        if trainable_beta:
            self.betas = nn.Parameter(betas_init)
            # Non-persistent so that checkpoints which never contained `sigmas` for
            # trainable schedules still load with strict=True.
            self.register_buffer("sigmas", sigmas, persistent=False)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('cum_betas', torch.cumsum(betas_init, dim=0) * self.dt)
            self.register_buffer("sigmas", sigmas)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Computes the beta schedule based on the specified method.

        Generates a sequence of beta values for the SDE noise schedule using the chosen
        method, ensuring values are clamped within the specified range.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) specifying the valid range for beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            Method for computing the beta schedule. Supported methods:
            "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported beta schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min ** 0.5, beta_max ** 0.5, num_steps)
            beta = x ** 2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            inv = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = inv.max() - inv.min()
            if span > 0:
                # Min-max normalise 1/t into [beta_min, beta_max].
                beta = beta_min + (beta_max - beta_min) * (inv - inv.min()) / span
            else:
                # A single step has no range to normalise (0/0 would give NaN);
                # use beta_min, matching torch.linspace's single-point behaviour.
                beta = torch.full((num_steps,), beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")
        beta = torch.clamp(beta, min=beta_min, max=beta_max)
        return beta

    def constrain_betas(self):
        """Constrains trainable betas to a valid range during training.

        Ensures that trainable beta values remain within the specified range
        [beta_start, beta_end] by clamping them in-place.

        Notes
        -----
        This method only applies when `trainable_beta` is True.
        """
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)

    def _cumulative_betas(self):
        """Return the discretised integral of beta, shape (num_steps,).

        Fixed schedules use the precomputed `cum_betas` buffer. Trainable schedules
        recompute ``cumsum(betas) * dt`` from the live parameter so the result
        reflects training updates and gradients flow back to `betas`.
        """
        if self.trainable_beta:
            return torch.cumsum(self.betas, dim=0) * self.dt
        return self.cum_betas

    def get_variance(self, time_steps, method):
        """Computes the variance for the specified SDE method at given time steps.

        With ``B(t)`` the cumulative beta (``cumsum(betas) * dt``), the variances are:

        - "ve":     ``sigma(t) ** 2``
        - "vp":     ``1 - exp(-B(t))``
        - "sub-vp": ``(1 - exp(-B(t))) ** 2``

        These are the perturbation-kernel variances of Song et al. (2021).

        Parameters
        ----------
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, num_steps - 1].
        method : str
            SDE method to compute variance for. Supported methods: "ve", "vp", "sub-vp".

        Returns
        -------
        torch.Tensor
            Variance values for the specified time steps, shape (batch_size,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp").
        """
        if method == "ve":
            return self.sigmas[time_steps] ** 2
        if method == "vp":
            return 1 - torch.exp(-self._cumulative_betas()[time_steps])
        if method == "sub-vp":
            # Sub-VP kernel variance is the *square* of the VP term; 1 - exp(-2B)
            # is the sub-VP diffusion-coefficient factor, not the marginal variance.
            return (1 - torch.exp(-self._cumulative_betas()[time_steps])) ** 2
        raise ValueError(f"Unknown method: {method}")