TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/reverse_ldm.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
+
7
class ReverseSDE(nn.Module):
    """Single reverse-time update step for SDE-based diffusion sampling.

    Parameters
    ----------
    hyper_params : object
        Schedule container. ``forward`` reads ``dt``, ``betas`` and
        ``cum_betas`` for every method, and additionally ``sigmas`` when
        ``method == "ve"``.
    method : str
        Update rule to apply: ``"ve"``, ``"vp"``, ``"sub-vp"`` or ``"ode"``.
        The value is only validated when ``forward`` is called.
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    def forward(self, xt, noise, predicted_noise, time_steps):
        """Apply one reverse step to ``xt``.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample. Per-step coefficients are reshaped to
            ``(-1, 1, 1, 1)``, so a 4-D batch-first tensor is presumably
            expected (confirm against callers).
        noise : torch.Tensor or None
            Noise for the stochastic term; ``None`` disables that term.
            Not used at all by the ``"ode"`` method.
        predicted_noise : torch.Tensor
            Output of the noise predictor for ``xt``.
        time_steps : torch.Tensor
            Integer tensor of per-sample step indices into the schedules.

        Returns
        -------
        torch.Tensor
            The updated sample.

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported names.
        """
        hp = self.hyper_params
        dt = hp.dt
        betas = hp.betas[time_steps].view(-1, 1, 1, 1)
        cum_betas = hp.cum_betas[time_steps].view(-1, 1, 1, 1)

        if self.method == "ve":
            sigma_t = hp.sigmas[time_steps]
            # Pick the previous sigma per sample. A batch-wide test
            # (``time_steps.min() > 0``) would zero it for *every* sample as
            # soon as one of them sits at step 0. The clamp keeps the gather
            # index valid for step 0; torch.where then substitutes 0 there.
            prev_idx = torch.clamp(time_steps - 1, min=0)
            sigma_t_prev = torch.where(
                time_steps > 0, hp.sigmas[prev_idx], torch.zeros_like(sigma_t)
            )
            sigma_sq_diff = (sigma_t ** 2 - sigma_t_prev ** 2).view(-1, 1, 1, 1)
            drift = -sigma_sq_diff * predicted_noise * dt
            if noise is not None:
                diffusion = torch.sqrt(torch.clamp(sigma_sq_diff, min=0)) * noise
            else:
                diffusion = 0
            # Hard bound against numerical blow-up.
            xt = torch.clamp(xt + drift + diffusion, -1e5, 1e5)

        elif self.method == "vp":
            drift = -0.5 * betas * xt * dt - betas * predicted_noise * dt
            diffusion = torch.sqrt(betas * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.method == "sub-vp":
            damping = 1 - torch.exp(-2 * cum_betas)
            drift = -0.5 * betas * xt * dt - betas * damping * predicted_noise * dt
            diffusion = torch.sqrt(betas * damping * dt) * noise if noise is not None else 0
            xt = xt + drift + diffusion

        elif self.method == "ode":
            # Deterministic update: no stochastic term. The original code also
            # carried a nested ``method == "ve"`` check here that could never
            # be true inside this branch; it has been dropped.
            drift = -0.5 * betas * xt * dt - 0.5 * betas * predicted_noise * dt
            xt = torch.clamp(xt + drift, -1e5, 1e5)

        else:
            raise ValueError(f"Unknown method: {self.method}")

        return xt
49
+
50
+
51
+
52
+
53
class ReverseDDIM(nn.Module):
    """Single reverse step of DDIM sampling on the subsampled (tau) schedule.

    Parameters
    ----------
    hyper_params : object
        Schedule container. ``forward`` reads ``tau_num_steps``, ``eta`` and
        calls ``get_tau_schedule()``, whose 4th and 5th return values are used
        as sqrt(alpha_bar) and sqrt(1 - alpha_bar) along the tau schedule.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, xt, predicted_noise, time_steps, prev_time_steps):
        """Move ``xt`` from ``time_steps`` to ``prev_time_steps``.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample. Coefficients are reshaped to ``(-1, 1, 1, 1)``,
            so a 4-D batch-first tensor is presumably expected.
        predicted_noise : torch.Tensor
            Noise predicted for ``xt`` at ``time_steps``.
        time_steps, prev_time_steps : torch.Tensor
            Integer index tensors into the tau schedule.

        Returns
        -------
        tuple of torch.Tensor
            ``(xt_prev, x0)``: the sample at the previous step and the
            estimate of the clean sample.

        Raises
        ------
        ValueError
            If any index lies outside ``[0, tau_num_steps)``.
        """
        hp = self.hyper_params
        upper = hp.tau_num_steps
        if not torch.all((time_steps >= 0) & (time_steps < upper)):
            raise ValueError(f"time_steps must be between 0 and {upper - 1}")
        if not torch.all((prev_time_steps >= 0) & (prev_time_steps < upper)):
            raise ValueError(f"prev_time_steps must be between 0 and {upper - 1}")

        _, _, _, sqrt_ab, sqrt_one_minus_ab = hp.get_tau_schedule()

        def _gather(table, idx):
            # Per-sample coefficient, broadcastable against ``xt``.
            return table[idx].to(xt.device).view(-1, 1, 1, 1)

        sqrt_ab_t = _gather(sqrt_ab, time_steps)
        sqrt_1m_ab_t = _gather(sqrt_one_minus_ab, time_steps)
        sqrt_ab_prev = _gather(sqrt_ab, prev_time_steps)
        sqrt_1m_ab_prev = _gather(sqrt_one_minus_ab, prev_time_steps)

        eta = hp.eta
        x0 = (xt - sqrt_1m_ab_t * predicted_noise) / sqrt_ab_t

        # DDIM stochasticity (Song et al., Eq. 16):
        #   sigma = eta * sqrt((1 - ab_prev) / (1 - ab_t)) * sqrt(1 - ab_t / ab_prev)
        # The previous expression reduced to eta * sqrt(1 - ab_prev) / sqrt(ab_prev),
        # which can exceed sqrt(1 - ab_prev) and forced the direction term to
        # be clamped. For eta == 0 both forms give sigma == 0.
        ab_ratio = (sqrt_ab_t / torch.clamp(sqrt_ab_prev, min=1e-8)) ** 2
        noise_coeff = (
            eta
            * (sqrt_1m_ab_prev / torch.clamp(sqrt_1m_ab_t, min=1e-8))
            * torch.sqrt(torch.clamp(1 - ab_ratio, min=0))
        )
        direction_coeff = torch.clamp(sqrt_1m_ab_prev ** 2 - noise_coeff ** 2, min=1e-8).sqrt()

        # Noise is drawn even when eta == 0 so the RNG stream does not depend on eta.
        xt_prev = sqrt_ab_prev * x0 + noise_coeff * torch.randn_like(xt) + direction_coeff * predicted_noise
        return xt_prev, x0
79
+
80
+
81
+
82
class ReverseDDPM(nn.Module):
    """Single reverse (denoising) step of DDPM.

    Parameters
    ----------
    hyper_params : object
        Schedule container. ``forward`` reads ``num_steps`` and
        ``trainable_beta``; it then either calls
        ``compute_schedule(hyper_params.betas)`` (using its first three return
        values as betas, alphas, alpha_bars) or reads the ``betas``,
        ``alphas`` and ``alpha_bars`` attributes directly.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params  # hyperparameters class

    def forward(self, xt, predicted_noise, time_steps):
        """Sample x_{t-1} given x_t and the predicted noise.

        Parameters
        ----------
        xt : torch.Tensor
            Current sample. Coefficients are reshaped to ``(-1, 1, 1, 1)``,
            so a 4-D batch-first tensor is presumably expected.
        predicted_noise : torch.Tensor
            Noise predicted for ``xt``.
        time_steps : torch.Tensor
            Integer tensor of per-sample step indices.

        Returns
        -------
        torch.Tensor
            The posterior mean for samples at step 0, and the mean plus
            scaled Gaussian noise for all other samples.

        Raises
        ------
        ValueError
            If any index lies outside ``[0, num_steps)``.
        """
        hp = self.hyper_params
        if not torch.all((time_steps >= 0) & (time_steps < hp.num_steps)):
            raise ValueError(f"time_steps must be between 0 and {hp.num_steps - 1}")

        # Keep the *full* schedules under their own names. The earlier code
        # overwrote ``alpha_bars_t`` with its per-batch gather and then indexed
        # it again with ``time_steps - 1``, which reads the wrong entries or
        # raises IndexError when ``trainable_beta`` is enabled.
        if hp.trainable_beta:
            betas, alphas, alpha_bars, _, _ = hp.compute_schedule(hp.betas)
        else:
            betas, alphas, alpha_bars = hp.betas, hp.alphas, hp.alpha_bars

        def _gather(table, idx):
            # Per-sample coefficient, broadcastable against ``xt``.
            return table[idx].to(xt.device).view(-1, 1, 1, 1)

        betas_t = _gather(betas, time_steps)
        alphas_t = _gather(alphas, time_steps)
        alpha_bars_t = _gather(alpha_bars, time_steps)

        mu = (xt - (betas_t / torch.sqrt(1 - alpha_bars_t)) * predicted_noise) / torch.sqrt(alphas_t)

        at_first_step = time_steps == 0
        if at_first_step.all():
            # No noise is added when every sample is at the final denoising step.
            return mu

        # Clamp so step-0 samples in a mixed batch do not wrap around to the
        # last schedule entry; their noise term is masked out below anyway.
        prev_idx = torch.clamp(time_steps - 1, min=0)
        alpha_bars_prev = _gather(alpha_bars, prev_idx)

        # Posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.
        std = torch.sqrt((1 - alpha_bars_prev) / (1 - alpha_bars_t) * betas_t)

        z = torch.randn_like(xt)
        noise_mask = (~at_first_step).float().to(xt.device).view(-1, 1, 1, 1)
        return mu + noise_mask * std * z
ldm/sample_ldm.py ADDED
@@ -0,0 +1,254 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertTokenizer
4
+
5
+
6
class SampleLDM(nn.Module):
    """Sampler for generating images using Latent Diffusion Models (LDM).

    Generates images by iteratively denoising in the latent space with a
    reverse diffusion module and decoding the result back to image space with
    a pre-trained compressor, as described in Rombach et al. (2022). Supports
    DDPM, DDIM and SDE reverse processes, with optional text conditioning.

    Parameters
    ----------
    model : str
        Diffusion model type. Supported: "ddpm", "ddim", "sde".
    reverse_diffusion : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
        Must expose ``hyper_params.num_steps`` (ddpm/sde) or
        ``hyper_params.tau_num_steps`` (ddim).
    noise_predictor : nn.Module
        Model to predict noise added during the forward diffusion process.
    compressor_model : nn.Module
        Pre-trained model providing ``encode`` (first return value is used as
        the latent) and ``decode``.
    image_shape : tuple
        Spatial shape (height, width) of the initial random tensor.
    conditional_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str or BertTokenizer, optional
        Name/path handed to ``BertTokenizer.from_pretrained``, or an already
        constructed tokenizer instance which is used as-is
        (default: "bert-base-uncased").
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of channels of the initial random tensor (default: 3).
    device : torch.device, optional
        Device for computation (default: CUDA if available, else CPU).
    max_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    output_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).

    Attributes
    ----------
    device : torch.device
        Device used for computation.
    model : str
        Diffusion model type ("ddpm", "ddim", "sde").
    noise_predictor : nn.Module
        Noise prediction model.
    reverse : nn.Module
        Reverse diffusion module.
    compressor : nn.Module
        Compressor model for latent space encoding/decoding.
    conditional_model : nn.Module or None
        Conditional model for text-based generation, if provided.
    tokenizer : BertTokenizer
        Tokenizer for text prompts.
    in_channels, image_shape, batch_size, max_length, output_range
        Stored constructor arguments.

    Raises
    ------
    ValueError
        If `image_shape` is not a tuple of two positive integers, `batch_size`
        is not positive, `in_channels` is not positive, or `output_range` is
        not a tuple (min, max) with min < max.
    """

    def __init__(self, model, reverse_diffusion, noise_predictor, compressor_model, image_shape, conditional_model=None,
                 tokenizer="bert-base-uncased", batch_size=1, in_channels=3, device=None, max_length=77, output_range=(-1, 1)):
        super().__init__()

        # Validate first so bad arguments fail before any module is moved or
        # any tokenizer files are loaded.
        if (not isinstance(image_shape, (tuple, list)) or len(image_shape) != 2
                or not all(isinstance(s, int) and s > 0 for s in image_shape)):
            raise ValueError("image_shape must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if in_channels <= 0:
            raise ValueError("in_channels must be positive")
        if not isinstance(output_range, (tuple, list)) or len(output_range) != 2 or output_range[0] >= output_range[1]:
            raise ValueError("output_range must be a tuple (min, max) with min < max")

        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        self.noise_predictor = noise_predictor.to(self.device)
        self.reverse = reverse_diffusion.to(self.device)
        self.compressor = compressor_model.to(self.device)
        # ``is not None`` instead of truthiness: container modules define
        # ``__len__`` and an empty one would otherwise count as "absent".
        self.conditional_model = conditional_model.to(self.device) if conditional_model is not None else None
        # A ready-made tokenizer instance is used directly; only names/paths
        # go through ``from_pretrained``.
        if isinstance(tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.in_channels = in_channels
        self.image_shape = image_shape
        self.batch_size = batch_size
        self.max_length = max_length
        self.output_range = output_range

    def tokenize(self, prompts):
        """Tokenizes text prompts for conditional generation.

        Parameters
        ----------
        prompts : str or list
            A single prompt or a list of prompt strings.

        Returns
        -------
        tuple
            ``(input_ids, attention_mask)`` as returned by the tokenizer with
            ``return_tensors="pt"``, padded/truncated to ``max_length`` and
            moved to ``self.device``.

        Raises
        ------
        TypeError
            If `prompts` is not a string or a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")

        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(self, conditions=None, normalize_output=True):
        """Generates images using the reverse diffusion process in the latent space.

        A random tensor of shape ``(batch_size, in_channels, *image_shape)``
        is passed through ``compressor.encode``; the first returned value is
        the starting latent. It is then denoised step by step and decoded
        with ``compressor.decode``.

        Parameters
        ----------
        conditions : str or list, optional
            Text prompt(s) for conditional generation, default None.
        normalize_output : bool, optional
            If True, rescales the clamped output from `output_range` to
            [0, 1] (default: True).

        Returns
        -------
        torch.Tensor
            Decoded images, clamped to `output_range` and, if
            `normalize_output` is True, rescaled to [0, 1].

        Raises
        ------
        ValueError
            If `conditions` is provided but no conditional model is specified,
            if a conditional model is specified but `conditions` is None, or
            if `model` is not one of "ddpm", "ddim", "sde".

        Notes
        -----
        - Sampling runs under `torch.no_grad()` with all sub-modules in
          evaluation mode.
        - For DDIM the loop length is `tau_num_steps`; for DDPM/SDE it is
          `num_steps`.
        - Prompts are tokenized and embedded once, before the loop; the
          embedding does not depend on the step.
        """
        if conditions is not None and self.conditional_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conditions is None and self.conditional_model is not None:
            raise ValueError("Conditions must be provided for conditional model")
        # Reject unsupported model types before doing any work.
        if self.model not in ("ddpm", "ddim", "sde"):
            raise ValueError(f"Unknown model: {self.model}. Supported: ddpm, ddim, sde")

        noisy_samples = torch.randn(
            self.batch_size, self.in_channels, self.image_shape[0], self.image_shape[1]
        ).to(self.device)

        self.noise_predictor.eval()
        self.compressor.eval()
        self.reverse.eval()
        if self.conditional_model is not None:
            self.conditional_model.eval()

        with torch.no_grad():
            xt, _ = self.compressor.encode(noisy_samples)

            if self.model == "ddim":
                num_steps = self.reverse.hyper_params.tau_num_steps
            else:
                num_steps = self.reverse.hyper_params.num_steps

            # Loop-invariant: the prompt embedding is the same at every step.
            y = None
            if self.conditional_model is not None:
                input_ids, attention_masks = self.tokenize(conditions)
                key_padding_mask = (attention_masks == 0)
                y = self.conditional_model(input_ids, key_padding_mask)

            # The ODE variant of the SDE sampler is deterministic and takes no noise.
            stochastic_sde = self.model == "sde" and getattr(self.reverse, "method", None) != "ode"

            for t in reversed(range(num_steps)):
                time_steps = torch.full((self.batch_size,), t, device=self.device, dtype=torch.long)

                noise = torch.randn_like(xt) if stochastic_sde else None

                if y is not None:
                    predicted_noise = self.noise_predictor(xt, time_steps, y)
                else:
                    predicted_noise = self.noise_predictor(xt, time_steps)

                if self.model == "sde":
                    xt = self.reverse(xt, noise, predicted_noise, time_steps)
                elif self.model == "ddim":
                    prev_time_steps = torch.full(
                        (self.batch_size,), max(t - 1, 0), device=self.device, dtype=torch.long
                    )
                    xt, _ = self.reverse(xt, predicted_noise, time_steps, prev_time_steps)
                else:  # "ddpm"
                    xt = self.reverse(xt, predicted_noise, time_steps)

            x = self.compressor.decode(xt)
            low, high = self.output_range[0], self.output_range[1]
            generated_imgs = torch.clamp(x, min=low, max=high)
            if normalize_output:
                generated_imgs = (generated_imgs - low) / (high - low)

        return generated_imgs

    def to(self, device):
        """Moves the module and its components to the specified device.

        Parameters
        ----------
        device : torch.device
            Target device for computation.

        Returns
        -------
        self
            The module moved to the specified device.

        Notes
        -----
        - Updates ``self.device`` and moves `noise_predictor`, `reverse`,
          `compressor`, and `conditional_model` (if present).
        """
        self.device = device
        self.noise_predictor.to(device)
        self.reverse.to(device)
        self.compressor.to(device)
        if self.conditional_model is not None:
            self.conditional_model.to(device)
        return super().to(device)