TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ldm/train_ldm.py ADDED
@@ -0,0 +1,412 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ import torch.nn as nn
4
+ from torch.cuda.amp import GradScaler, autocast
5
+ from torch.amp import GradScaler, autocast
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ from transformers import BertTokenizer
8
+ import warnings
9
+
10
+
11
+
12
+
13
+ class TrainLDM(nn.Module):
14
+ """Trainer for Latent Diffusion Models (LDM).
15
+
16
+ Manages the training process for LDMs, optimizing a noise predictor to learn the noise
17
+ added by the forward diffusion process in the latent space, as described in Rombach
18
+ et al. (2022). Uses a pre-trained compressor model to encode images into a latent
19
+ space, supports conditional training with text prompts, mixed precision, learning rate
20
+ scheduling, early stopping, and checkpointing.
21
+
22
+ Parameters
23
+ ----------
24
+ forward_model : nn.Module
25
+ Forward diffusion module (e.g., ForwardDDPM, ForwardSDE) to add noise in the latent space.
26
+ hyper_params_model : nn.Module
27
+ Hyperparameter module (e.g., HyperParamsDDPM, HyperParamsSDE) defining the noise schedule.
28
+ noise_predictor : nn.Module
29
+ Model to predict noise added during the forward diffusion process.
30
+ compressor_model : nn.Module
31
+ Pre-trained model to encode images into the latent space and decode back (e.g., autoencoder).
32
+ optimizer : torch.optim.Optimizer
33
+ Optimizer for training the noise predictor and conditional model (if applicable).
34
+ objective : callable
35
+ Loss function to compute the difference between predicted and actual noise.
36
+ data_loader : torch.utils.data.DataLoader
37
+ DataLoader for training data.
38
+ conditional_model : nn.Module, optional
39
+ Model for conditional generation (e.g., text embeddings), default None.
40
+ val_loader : torch.utils.data.DataLoader, optional
41
+ DataLoader for validation data, default None.
42
+ max_epoch : int, optional
43
+ Maximum number of training epochs (default: 1000).
44
+ device : torch.device, optional
45
+ Device for computation (default: CUDA if available, else CPU).
46
+ store_path : str, optional
47
+ Path to save model checkpoints (default: "ldm_model.pth").
48
+ patience : int, optional
49
+ Number of epochs to wait for improvement before early stopping (default: 10).
50
+ warmup_epochs : int, optional
51
+ Number of epochs for learning rate warmup (default: 100).
52
+ tokenizer : BertTokenizer, optional
53
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
54
+ max_length : int, optional
55
+ Maximum length for tokenized prompts (default: 77).
56
+ val_frequency : int, optional
57
+ Frequency (in epochs) for validation (default: 10).
58
+
59
+ Attributes
60
+ ----------
61
+ device : torch.device
62
+ Device used for computation.
63
+ forward_diffusion : nn.Module
64
+ Forward diffusion module.
65
+ hyper_params_model : nn.Module
66
+ Hyperparameter module for the noise schedule.
67
+ noise_predictor : nn.Module
68
+ Noise prediction model.
69
+ compressor_model : nn.Module
70
+ Compressor model for latent space encoding/decoding.
71
+ conditional_model : nn.Module or None
72
+ Conditional model for text-based training, if provided.
73
+ optimizer : torch.optim.Optimizer
74
+ Optimizer for training.
75
+ objective : callable
76
+ Loss function for training.
77
+ data_loader : torch.utils.data.DataLoader
78
+ Training data loader.
79
+ val_loader : torch.utils.data.DataLoader or None
80
+ Validation data loader, if provided.
81
+ max_epoch : int
82
+ Maximum training epochs.
83
+ store_path : str
84
+ Path for saving checkpoints.
85
+ max_length : int
86
+ Maximum length for tokenized prompts.
87
+ patience : int
88
+ Patience for early stopping.
89
+ scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
90
+ Learning rate scheduler based on validation or training loss.
91
+ warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
92
+ Learning rate scheduler for warmup.
93
+ tokenizer : BertTokenizer
94
+ Tokenizer for text prompts.
95
+ val_frequency : int
96
+ Frequency for validation.
97
+
98
+ Raises
99
+ ------
100
+ ValueError
101
+ If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
102
+ """
103
+ def __init__(self, forward_model, hyper_params_model, noise_predictor, compressor_model, optimizer, objective, data_loader,
104
+ conditional_model=None, val_loader=None, max_epoch=1000, device=None, store_path=None,
105
+ patience=10, warmup_epochs=100, tokenizer=None, max_length=77, val_frequency=10):
106
+ super().__init__()
107
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
+ self.forward_diffusion = forward_model.to(device)
109
+ self.hyper_params_model = hyper_params_model.to(device)
110
+ self.noise_predictor = noise_predictor
111
+ self.compressor_model = compressor_model
112
+ self.conditional_model = conditional_model
113
+ self.optimizer = optimizer
114
+ self.objective = objective
115
+ self.data_loader = data_loader
116
+ self.val_loader = val_loader
117
+ self.max_epoch = max_epoch
118
+ self.store_path = store_path or "ldm_model.pth"
119
+ self.max_length = max_length
120
+ self.patience = patience
121
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
122
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
123
+ self.val_frequency = val_frequency
124
+ if tokenizer is None:
125
+ try:
126
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
127
+ except Exception as e:
128
+ raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
129
+
130
+ def load_checkpoint(self, checkpoint_path):
131
+ """Loads a training checkpoint to resume training.
132
+
133
+ Restores the state of the noise predictor, conditional model (if applicable),
134
+ and optimizer from a saved checkpoint.
135
+
136
+ Parameters
137
+ ----------
138
+ checkpoint_path : str
139
+ Path to the checkpoint file.
140
+
141
+ Returns
142
+ -------
143
+ tuple
144
+ A tuple containing:
145
+ - epoch: The epoch at which the checkpoint was saved (int).
146
+ - loss: The loss at the checkpoint (float).
147
+
148
+ Raises
149
+ ------
150
+ FileNotFoundError
151
+ If the checkpoint file is not found.
152
+ KeyError
153
+ If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
154
+ or 'optimizer_state_dict').
155
+
156
+ Warns
157
+ -----
158
+ warnings.warn
159
+ If the optimizer state cannot be loaded, if the checkpoint contains a
160
+ conditional model state but none is defined, or if no conditional model
161
+ state is provided when expected.
162
+ """
163
+ try:
164
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
165
+ except FileNotFoundError:
166
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
167
+
168
+ if 'model_state_dict_noise_predictor' not in checkpoint:
169
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
170
+ self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
171
+
172
+ if self.conditional_model is not None:
173
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
174
+ self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
175
+ else:
176
+ warnings.warn(
177
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
178
+ elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
179
+ warnings.warn(
180
+ "Checkpoint contains conditional model state, but no conditional model is defined in this instance")
181
+
182
+ if 'optimizer_state_dict' not in checkpoint:
183
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
184
+ try:
185
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
186
+ except ValueError as e:
187
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
188
+
189
+ epoch = checkpoint.get('epoch', -1)
190
+ loss = checkpoint.get('loss', float('inf'))
191
+
192
+ self.noise_predictor.to(self.device)
193
+ if self.conditional_model is not None:
194
+ self.conditional_model.to(self.device)
195
+
196
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
197
+ return epoch, loss
198
+
199
+ @staticmethod
200
+ def warmup_scheduler(optimizer, warmup_epochs=10):
201
+ """Creates a learning rate scheduler for warmup.
202
+
203
+ Generates a scheduler that linearly increases the learning rate from 0 to the
204
+ optimizer's initial value over the specified warmup epochs, then maintains it.
205
+
206
+ Parameters
207
+ ----------
208
+ optimizer : torch.optim.Optimizer
209
+ Optimizer to apply the scheduler to.
210
+ warmup_epochs : int, optional
211
+ Number of epochs for the warmup phase (default: 10).
212
+
213
+ Returns
214
+ -------
215
+ torch.optim.lr_scheduler.LambdaLR
216
+ Learning rate scheduler for warmup.
217
+ """
218
+ def lr_lambda(epoch):
219
+ if epoch < warmup_epochs:
220
+ return epoch / warmup_epochs
221
+ return 1.0
222
+
223
+ return LambdaLR(optimizer, lr_lambda)
224
+
225
+ def forward(self):
226
+ """Trains the LDM to predict noise added by the forward diffusion process in the latent space.
227
+
228
+ Executes the training loop, optimizing the noise predictor and conditional model
229
+ (if applicable) using mixed precision, gradient clipping, and learning rate
230
+ scheduling. Uses a pre-trained compressor model to encode images into the latent
231
+ space. Supports validation, early stopping, and checkpointing.
232
+
233
+ Returns
234
+ -------
235
+ tuple
236
+ A tuple containing:
237
+ - train_losses: List of mean training losses per epoch (list of float).
238
+ - best_val_loss: Best validation or training loss achieved (float).
239
+
240
+ Notes
241
+ -----
242
+ - Training uses mixed precision via `torch.cuda.amp` for efficiency.
243
+ - The compressor model is assumed pre-trained and set to evaluation mode.
244
+ - Checkpoints are saved when the validation (or training) loss improves, and on
245
+ early stopping.
246
+ - Early stopping is triggered if no improvement occurs for `patience` epochs.
247
+ """
248
+
249
+ self.noise_predictor.train()
250
+ self.noise_predictor.to(self.device)
251
+ if self.conditional_model is not None:
252
+ self.conditional_model.train()
253
+ self.conditional_model.to(self.device)
254
+ if self.compressor_model is not None:
255
+ self.compressor_model.eval() # the model is already trained
256
+ self.compressor_model.to(self.device)
257
+
258
+ scaler = GradScaler()
259
+ train_losses = []
260
+ best_val_loss = float("inf")
261
+ wait = 0
262
+ for epoch in range(self.max_epoch):
263
+ train_losses_ = []
264
+ for x, y in tqdm(self.data_loader):
265
+ x = x.to(self.device)
266
+ with torch.no_grad():
267
+ x, _ = self.compressor_model.encode(x) # using compressor model bring x to latent space
268
+ if self.conditional_model is not None:
269
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
270
+ y_list = [str(item) for item in y_list]
271
+ y_encoded = self.tokenizer(
272
+ y_list,
273
+ padding="max_length",
274
+ truncation=True,
275
+ max_length=self.max_length,
276
+ return_tensors="pt"
277
+ ).to(self.device)
278
+ input_ids = y_encoded["input_ids"]
279
+ attention_mask = y_encoded["attention_mask"]
280
+ y_encoded = self.conditional_model(input_ids, attention_mask)
281
+ else:
282
+ y_encoded = None
283
+
284
+ self.optimizer.zero_grad()
285
+ with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
286
+ noise = torch.randn_like(x).to(self.device)
287
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
288
+ assert x.device == noise.device == t.device, "Device mismatch detected"
289
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
290
+ noisy_x = self.forward_diffusion(x, noise, t)
291
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
292
+ loss = self.objective(p_noise, noise)
293
+ scaler.scale(loss).backward()
294
+
295
+ nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
296
+ if self.conditional_model is not None:
297
+ nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
298
+ scaler.step(self.optimizer)
299
+ scaler.update()
300
+ self.warmup_lr_scheduler.step()
301
+ train_losses_.append(loss.item())
302
+
303
+ mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
304
+ train_losses.append(mean_train_loss)
305
+ print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
306
+
307
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
308
+ val_loss = self.validate()
309
+ print(f" | Val Loss: {val_loss:.4f}")
310
+ current_best = val_loss
311
+ self.scheduler.step(val_loss)
312
+ else:
313
+ print()
314
+ current_best = mean_train_loss
315
+ self.scheduler.step(mean_train_loss)
316
+
317
+ if current_best < best_val_loss:
318
+ best_val_loss = current_best
319
+ wait = 0
320
+ try:
321
+ torch.save({
322
+ 'epoch': epoch + 1,
323
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
324
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
325
+ 'optimizer_state_dict': self.optimizer.state_dict(),
326
+ 'loss': best_val_loss,
327
+ 'hyper_params_model': self.hyper_params_model,
328
+ 'max_epoch': self.max_epoch,
329
+ }, self.store_path)
330
+ print(f"Model saved at epoch {epoch + 1}")
331
+ except Exception as e:
332
+ print(f"Failed to save model: {e}")
333
+ else:
334
+ wait += 1
335
+ if wait >= self.patience:
336
+ print("Early stopping triggered")
337
+ try:
338
+ torch.save({
339
+ 'epoch': epoch + 1,
340
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
341
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
342
+ 'optimizer_state_dict': self.optimizer.state_dict(),
343
+ 'loss': best_val_loss,
344
+ 'hyper_params_model': self.hyper_params_model,
345
+ 'max_epoch': self.max_epoch,
346
+ }, self.store_path + "_early_stop.pth")
347
+ print(f"Final model saved at {self.store_path}_early_stop.pth")
348
+ except Exception as e:
349
+ print(f"Failed to save final model: {e}")
350
+ break
351
+
352
+ return train_losses, best_val_loss
353
+
354
+ def validate(self):
355
+ """Validates the LDM on the validation dataset.
356
+
357
+ Computes the validation loss using the noise predictor and forward diffusion
358
+ process in the latent space, with optional conditional inputs.
359
+
360
+ Returns
361
+ -------
362
+ float
363
+ Mean validation loss across the validation dataset.
364
+
365
+ Notes
366
+ -----
367
+ - Validation is performed with `torch.no_grad()` for efficiency.
368
+ - The compressor model is used to encode validation data into the latent space.
369
+ - The noise predictor and conditional model (if applicable) are set to evaluation
370
+ mode during validation and restored to training mode afterward.
371
+ """
372
+ self.noise_predictor.eval()
373
+ if self.conditional_model is not None:
374
+ self.conditional_model.eval()
375
+
376
+ val_losses = []
377
+ with torch.no_grad():
378
+ for x, y in self.val_loader:
379
+ x = x.to(self.device)
380
+ with torch.no_grad():
381
+ x, _ = self.compressor_model.encode(x)
382
+
383
+ if self.conditional_model is not None:
384
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
385
+ y_list = [str(item) for item in y_list]
386
+ y_encoded = self.tokenizer(
387
+ y_list,
388
+ padding="max_length",
389
+ truncation=True,
390
+ max_length=self.max_length,
391
+ return_tensors="pt"
392
+ ).to(self.device)
393
+ input_ids = y_encoded["input_ids"]
394
+ attention_mask = y_encoded["attention_mask"]
395
+ y_encoded = self.conditional_model(input_ids, attention_mask)
396
+ else:
397
+ y_encoded = None
398
+
399
+ noise = torch.randn_like(x).to(self.device)
400
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
401
+ assert x.device == noise.device == t.device, "Device mismatch detected"
402
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
403
+ noisy_x = self.forward_diffusion(x, noise, t)
404
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
405
+ loss = self.objective(p_noise, noise)
406
+ val_losses.append(loss.item())
407
+
408
+ mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
409
+ self.noise_predictor.train()
410
+ if self.conditional_model is not None:
411
+ self.conditional_model.train()
412
+ return mean_val_loss
sde/__init__.py ADDED
File without changes
sde/forward_sde.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+
6
class ForwardSDE(nn.Module):
    """Forward diffusion process for SDE-based generative models.

    Implements the forward diffusion process for score-based generative models using
    Stochastic Differential Equations (SDEs), supporting Variance Exploding (VE),
    Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE methods, as
    described in Song et al. (2021).

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing SDE-specific parameters. Expected to have
        attributes:
        - `dt`: Time step size for SDE integration (float).
        - `sigmas`: Sigma values for VE method (torch.Tensor, optional).
        - `betas`: Beta values for VP, sub-VP, or ODE methods (torch.Tensor).
        - `cum_betas`: Cumulative beta values for sub-VP method (torch.Tensor, optional).
    method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    method : str
        Selected SDE method.

    Raises
    ------
    ValueError
        If `method` is not one of the supported methods ("ve", "vp", "sub-vp", "ode").
    """

    _SUPPORTED_METHODS = ("ve", "vp", "sub-vp", "ode")

    def __init__(self, hyper_params, method):
        super().__init__()
        # Fail fast, as documented, instead of only on the first forward call.
        if method not in self._SUPPORTED_METHODS:
            raise ValueError(f"Unknown method: {method}")
        self.hyper_params = hyper_params
        self.method = method

    def forward(self, x0, noise, time_steps):
        """Applies the forward SDE diffusion process to the input data.

        Perturbs the input data `x0` by adding noise according to the specified SDE
        method at given time steps, incorporating drift and diffusion terms as applicable.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor with the batch on the first dimension, e.g.
            (batch_size, channels, height, width). Any rank >= 1 is accepted; per-sample
            coefficients are broadcast over all non-batch dimensions.
        noise : torch.Tensor
            Gaussian noise tensor, same shape as `x0`.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), where each value
            is in the range [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp", "ode").
        """
        hp = self.hyper_params
        dt = hp.dt
        # Shape (batch, 1, ..., 1) so per-sample coefficients broadcast over x0.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)

        if self.method == "ve":
            sigma_t = hp.sigmas[time_steps]
            # Look up the previous sigma per sample. A sample at t == 0 has no
            # predecessor and uses 0, without affecting the other samples in the batch.
            sigma_prev = hp.sigmas[(time_steps - 1).clamp(min=0)]
            has_prev = (time_steps > 0).to(sigma_prev.device)
            sigma_t_prev = torch.where(has_prev, sigma_prev, torch.zeros_like(sigma_prev))
            sigma_diff = torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_t_prev ** 2, min=0))
            return x0 + noise * sigma_diff.reshape(coeff_shape)

        if self.method == "vp":
            betas = hp.betas[time_steps].reshape(coeff_shape)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * dt) * noise
            return x0 + drift + diffusion

        if self.method == "sub-vp":
            betas = hp.betas[time_steps].reshape(coeff_shape)
            cum_betas = hp.cum_betas[time_steps].reshape(coeff_shape)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
            return x0 + drift + diffusion

        if self.method == "ode":
            # Deterministic probability-flow step: drift only, no noise term.
            betas = hp.betas[time_steps].reshape(coeff_shape)
            drift = -0.5 * betas * x0 * dt
            return x0 + drift

        raise ValueError(f"Unknown method: {self.method}")