TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddim/train_ddim.py ADDED
@@ -0,0 +1,394 @@
1
+ """Training module for Denoising Diffusion Implicit Models (DDIM).
2
+
3
+ This module implements the training process for DDIM, as described in Song et al. (2021,
4
+ "Denoising Diffusion Implicit Models"). It supports both unconditional and conditional
5
+ training with text prompts, using mixed precision and learning rate scheduling.
6
+ """
7
+
8
+ import torch
9
+ from torch.cuda.amp import GradScaler, autocast
10
+ import torch.nn as nn
11
+ from tqdm import tqdm
12
+ from torch.amp import GradScaler, autocast
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from transformers import BertTokenizer
15
+ import warnings
16
+ from forward_ddim import ForwardDDIM
17
+
18
+
19
+
20
+
21
+ class TrainDDIM(nn.Module):
22
+ """Trainer for Denoising Diffusion Implicit Models (DDIM).
23
+
24
+ Manages the training process for DDIM, optimizing a noise predictor model to learn
25
+ the noise added by the forward diffusion process. Supports conditional training with
26
+ text prompts, mixed precision training, learning rate scheduling, early stopping, and
27
+ checkpointing, as inspired by Song et al. (2021).
28
+
29
+ Parameters
30
+ ----------
31
+ noise_predictor : nn.Module
32
+ Model to predict noise added during the forward diffusion process.
33
+ hyper_params_model : nn.Module
34
+ Hyperparameter module (e.g., HyperParamsDDIM) defining the noise schedule.
35
+ data_loader : torch.utils.data.DataLoader
36
+ DataLoader for training data.
37
+ optimizer : torch.optim.Optimizer
38
+ Optimizer for training the noise predictor and conditional model (if applicable).
39
+ objective : callable
40
+ Loss function to compute the difference between predicted and actual noise.
41
+ val_loader : torch.utils.data.DataLoader, optional
42
+ DataLoader for validation data, default None.
43
+ max_epoch : int, optional
44
+ Maximum number of training epochs (default: 1000).
45
+ device : torch.device, optional
46
+ Device for computation (default: CUDA if available, else CPU).
47
+ conditional_model : nn.Module, optional
48
+ Model for conditional generation (e.g., text embeddings), default None.
49
+ tokenizer : BertTokenizer, optional
50
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
51
+ max_length : int, optional
52
+ Maximum length for tokenized prompts (default: 77).
53
+ store_path : str, optional
54
+ Path to save model checkpoints (default: "ddim_model.pth").
55
+ patience : int, optional
56
+ Number of epochs to wait for improvement before early stopping (default: 10).
57
+ warmup_epochs : int, optional
58
+ Number of epochs for learning rate warmup (default: 100).
59
+ val_frequency : int, optional
60
+ Frequency (in epochs) for validation (default: 10).
61
+
62
+ Attributes
63
+ ----------
64
+ device : torch.device
65
+ Device used for computation.
66
+ noise_predictor : nn.Module
67
+ Noise prediction model.
68
+ hyper_params_model : nn.Module
69
+ Hyperparameter module for the noise schedule.
70
+ conditional_model : nn.Module or None
71
+ Conditional model for text-based training, if provided.
72
+ optimizer : torch.optim.Optimizer
73
+ Optimizer for training.
74
+ objective : callable
75
+ Loss function for training.
76
+ store_path : str
77
+ Path for saving checkpoints.
78
+ data_loader : torch.utils.data.DataLoader
79
+ Training data loader.
80
+ val_loader : torch.utils.data.DataLoader or None
81
+ Validation data loader, if provided.
82
+ max_epoch : int
83
+ Maximum training epochs.
84
+ max_length : int
85
+ Maximum length for tokenized prompts.
86
+ patience : int
87
+ Patience for early stopping.
88
+ scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
89
+ Learning rate scheduler based on validation or training loss.
90
+ forward_diffusion : ForwardDDIM
91
+ Forward diffusion module for DDIM.
92
+ warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
93
+ Learning rate scheduler for warmup.
94
+ val_frequency : int
95
+ Frequency for validation.
96
+ tokenizer : BertTokenizer
97
+ Tokenizer for text prompts.
98
+
99
+ Raises
100
+ ------
101
+ ValueError
102
+ If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
103
+ """
104
+ def __init__(self, noise_predictor, hyper_params_model, data_loader, optimizer, objective, val_loader=None,
105
+ max_epoch=1000, device=None, conditional_model=None, tokenizer=None, max_length=77,
106
+ store_path=None, patience=10, warmup_epochs=100, val_frequency=10):
107
+ super().__init__()
108
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
+ self.noise_predictor = noise_predictor
110
+ self.hyper_params_model = hyper_params_model.to(self.device)
111
+ self.conditional_model = conditional_model
112
+ self.optimizer = optimizer
113
+ self.objective = objective
114
+ self.store_path = store_path or "ddim_model.pth"
115
+ self.data_loader = data_loader
116
+ self.val_loader = val_loader
117
+ self.max_epoch = max_epoch
118
+ self.max_length = max_length
119
+ self.patience = patience
120
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
121
+ self.forward_diffusion = ForwardDDIM(hyper_params=self.hyper_params_model).to(self.device)
122
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
123
+ self.val_frequency = val_frequency
124
+ if tokenizer is None:
125
+ try:
126
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
127
+ except Exception as e:
128
+ raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
129
+
130
+ def load_checkpoint(self, checkpoint_path):
131
+ """Loads a training checkpoint to resume training.
132
+
133
+ Restores the state of the noise predictor, conditional model (if applicable),
134
+ and optimizer from a saved checkpoint.
135
+
136
+ Parameters
137
+ ----------
138
+ checkpoint_path : str
139
+ Path to the checkpoint file.
140
+
141
+ Returns
142
+ -------
143
+ tuple
144
+ A tuple containing:
145
+ - epoch: The epoch at which the checkpoint was saved (int).
146
+ - loss: The loss at the checkpoint (float).
147
+
148
+ Raises
149
+ ------
150
+ FileNotFoundError
151
+ If the checkpoint file is not found.
152
+ KeyError
153
+ If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
154
+ or 'optimizer_state_dict').
155
+
156
+ Warns
157
+ -----
158
+ warnings.warn
159
+ If the optimizer state cannot be loaded, if the checkpoint contains a
160
+ conditional model state but none is defined, or if no conditional model
161
+ state is provided when expected.
162
+ """
163
+ try:
164
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
165
+ except FileNotFoundError:
166
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
167
+
168
+ if 'model_state_dict_noise_predictor' not in checkpoint:
169
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
170
+ self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
171
+
172
+ if self.conditional_model is not None:
173
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
174
+ self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
175
+ else:
176
+ warnings.warn(
177
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
178
+ elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
179
+ warnings.warn(
180
+ "Checkpoint contains conditional model state, but no conditional model is defined in this instance")
181
+
182
+ if 'optimizer_state_dict' not in checkpoint:
183
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
184
+ try:
185
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
186
+ except ValueError as e:
187
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
188
+
189
+ epoch = checkpoint.get('epoch', -1)
190
+ loss = checkpoint.get('loss', float('inf'))
191
+
192
+ self.noise_predictor.to(self.device)
193
+ if self.conditional_model is not None:
194
+ self.conditional_model.to(self.device)
195
+
196
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
197
+ return epoch, loss
198
+
199
+ @staticmethod
200
+ def warmup_scheduler(optimizer, warmup_epochs=10):
201
+ """Creates a learning rate scheduler for warmup.
202
+
203
+ Generates a scheduler that linearly increases the learning rate from 0 to the
204
+ optimizer's initial value over the specified warmup epochs, then maintains it.
205
+
206
+ Parameters
207
+ ----------
208
+ optimizer : torch.optim.Optimizer
209
+ Optimizer to apply the scheduler to.
210
+ warmup_epochs : int, optional
211
+ Number of epochs for the warmup phase (default: 10).
212
+
213
+ Returns
214
+ -------
215
+ torch.optim.lr_scheduler.LambdaLR
216
+ Learning rate scheduler for warmup.
217
+ """
218
+ def lr_lambda(epoch):
219
+ if epoch < warmup_epochs:
220
+ return epoch / warmup_epochs
221
+ return 1.0
222
+
223
+ return LambdaLR(optimizer, lr_lambda)
224
+
225
+ def forward(self):
226
+ """Trains the DDIM model to predict noise added by the forward diffusion process.
227
+
228
+ Executes the training loop, optimizing the noise predictor and conditional model
229
+ (if applicable) using mixed precision, gradient clipping, and learning rate
230
+ scheduling. Supports validation, early stopping, and checkpointing.
231
+
232
+ Returns
233
+ -------
234
+ tuple
235
+ A tuple containing:
236
+ - train_losses: List of mean training losses per epoch (list of float).
237
+ - best_val_loss: Best validation or training loss achieved (float).
238
+
239
+ Notes
240
+ -----
241
+ - Training uses mixed precision via `torch.cuda.amp` or `torch.amp` for efficiency.
242
+ - Checkpoints are saved when the validation (or training) loss improves, and on early stopping.
243
+ - Early stopping is triggered if no improvement occurs for `patience` epochs.
244
+ """
245
+ self.noise_predictor.train()
246
+ self.noise_predictor.to(self.device)
247
+ if self.conditional_model is not None:
248
+ self.conditional_model.train()
249
+ self.conditional_model.to(self.device)
250
+
251
+ scaler = GradScaler()
252
+ train_losses = []
253
+ best_val_loss = float("inf")
254
+ wait = 0
255
+ for epoch in range(self.max_epoch):
256
+ train_losses_ = []
257
+ for x, y in tqdm(self.data_loader):
258
+ x = x.to(self.device)
259
+
260
+ if self.conditional_model is not None:
261
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
262
+ y_list = [str(item) for item in y_list]
263
+ y_encoded = self.tokenizer(
264
+ y_list,
265
+ padding="max_length",
266
+ truncation=True,
267
+ max_length=self.max_length,
268
+ return_tensors="pt"
269
+ ).to(self.device)
270
+ input_ids = y_encoded["input_ids"]
271
+ attention_mask = y_encoded["attention_mask"]
272
+ y_encoded = self.conditional_model(input_ids, attention_mask)
273
+ else:
274
+ y_encoded = None
275
+
276
+ self.optimizer.zero_grad()
277
+ with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
278
+ noise = torch.randn_like(x).to(self.device)
279
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
280
+ assert x.device == noise.device == t.device, "Device mismatch detected"
281
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
282
+ noisy_x = self.forward_diffusion(x, noise, t)
283
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
284
+ loss = self.objective(p_noise, noise)
285
+ scaler.scale(loss).backward()
286
+ nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
287
+ if self.conditional_model is not None:
288
+ nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
289
+ scaler.step(self.optimizer)
290
+ scaler.update()
291
+ self.warmup_lr_scheduler.step()
292
+ train_losses_.append(loss.item())
293
+
294
+ mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
295
+ train_losses.append(mean_train_loss)
296
+ print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
297
+
298
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
299
+ val_loss = self.validate()
300
+ print(f" | Val Loss: {val_loss:.4f}")
301
+ current_best = val_loss
302
+ self.scheduler.step(val_loss)
303
+ else:
304
+ print()
305
+ current_best = mean_train_loss
306
+ self.scheduler.step(mean_train_loss)
307
+
308
+ if current_best < best_val_loss:
309
+ best_val_loss = current_best
310
+ wait = 0
311
+ try:
312
+ torch.save({
313
+ 'epoch': epoch + 1,
314
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
315
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
316
+ 'optimizer_state_dict': self.optimizer.state_dict(),
317
+ 'loss': best_val_loss,
318
+ 'hyper_params_model': self.hyper_params_model,
319
+ 'max_epoch': self.max_epoch,
320
+ }, self.store_path)
321
+ print(f"Model saved at epoch {epoch + 1}")
322
+ except Exception as e:
323
+ print(f"Failed to save model: {e}")
324
+ else:
325
+ wait += 1
326
+ if wait >= self.patience:
327
+ print("Early stopping triggered")
328
+ try:
329
+ torch.save({
330
+ 'epoch': epoch + 1,
331
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
332
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
333
+ 'optimizer_state_dict': self.optimizer.state_dict(),
334
+ 'loss': best_val_loss,
335
+ 'hyper_params_model': self.hyper_params_model,
336
+ 'max_epoch': self.max_epoch,
337
+ }, self.store_path + "_early_stop.pth")
338
+ print(f"Final model saved at {self.store_path}_early_stop.pth")
339
+ except Exception as e:
340
+ print(f"Failed to save final model: {e}")
341
+ break
342
+
343
+ return train_losses, best_val_loss
344
+
345
+ def validate(self):
346
+ """Validates the DDIM model on the validation dataset.
347
+
348
+ Computes the validation loss using the noise predictor and forward diffusion
349
+ process, with optional conditional inputs.
350
+
351
+ Returns
352
+ -------
353
+ float
354
+ Mean validation loss across the validation dataset.
355
+ """
356
+ self.noise_predictor.eval()
357
+ if self.conditional_model is not None:
358
+ self.conditional_model.eval()
359
+
360
+ val_losses = []
361
+ with torch.no_grad():
362
+ for x, y in self.val_loader:
363
+ x = x.to(self.device)
364
+
365
+ if self.conditional_model is not None:
366
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
367
+ y_list = [str(item) for item in y_list]
368
+ y_encoded = self.tokenizer(
369
+ y_list,
370
+ padding="max_length",
371
+ truncation=True,
372
+ max_length=self.max_length,
373
+ return_tensors="pt"
374
+ ).to(self.device)
375
+ input_ids = y_encoded["input_ids"]
376
+ attention_mask = y_encoded["attention_mask"]
377
+ y_encoded = self.conditional_model(input_ids, attention_mask)
378
+ else:
379
+ y_encoded = None
380
+
381
+ noise = torch.randn_like(x).to(self.device)
382
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
383
+ assert x.device == noise.device == t.device, "Device mismatch detected"
384
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
385
+ noisy_x = self.forward_diffusion(x, noise, t)
386
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
387
+ loss = self.objective(p_noise, noise)
388
+ val_losses.append(loss.item())
389
+
390
+ mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
391
+ self.noise_predictor.train()
392
+ if self.conditional_model is not None:
393
+ self.conditional_model.train()
394
+ return mean_val_loss
ddpm/__init__.py ADDED
File without changes
ddpm/forward_ddpm.py ADDED
@@ -0,0 +1,89 @@
1
+ """Forward diffusion process for Denoising Diffusion Probabilistic Models (DDPM).
2
+
3
+ This module implements the forward diffusion process as described in the DDPM paper
4
+ (Ho et al., 2020, "Denoising Diffusion Probabilistic Models"). The forward process
5
+ gradually adds noise to the input data according to a predefined noise schedule.
6
+ """
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+
14
class ForwardDDPM(nn.Module):
    """Forward diffusion process of DDPM.

    Perturbs input data with Gaussian noise according to the closed-form DDPM
    forward process (Ho et al., 2020). The schedule is read from the supplied
    hyperparameter object and may be fixed or trainable.

    Parameters
    ----------
    hyper_params : object
        Noise-schedule holder. The attributes read by `forward` are:
        - `num_steps`: Number of diffusion steps (int).
        - `trainable_beta`: Whether the schedule is recomputed from `betas` (bool).
        - `betas` and `compute_schedule`: Used when `trainable_beta` is True;
          `compute_schedule(betas)` must return a 5-tuple whose last two items
          are sqrt(alpha_bar) and sqrt(1 - alpha_bar).
        - `sqrt_alpha_bars`, `sqrt_one_minus_alpha_bars`: Precomputed tensors,
          used when `trainable_beta` is False.

    Attributes
    ----------
    hyper_params : object
        The provided hyperparameter object.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Applies the forward diffusion process to the input data.

        Computes ``xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``
        with one timestep per batch element.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data with the batch on the first axis, e.g.
            (batch_size, channels, height, width). Any number of trailing
            axes is supported.
        noise : torch.Tensor
            Gaussian noise tensor of the same shape as `x0`.
        time_steps : torch.Tensor
            Integer timestep indices, shape (batch_size,), each in
            [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Noisy data `xt`, same shape as `x0`.

        Raises
        ------
        ValueError
            If any value in `time_steps` is outside
            [0, hyper_params.num_steps - 1].
        """
        num_steps = self.hyper_params.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute from the current betas so gradients can reach them.
            _, _, _, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
        else:
            sqrt_alpha_bars = self.hyper_params.sqrt_alpha_bars
            sqrt_one_minus_alpha_bars = self.hyper_params.sqrt_one_minus_alpha_bars

        # Index on the schedule's own device; the indices may live elsewhere.
        index = time_steps.to(sqrt_alpha_bars.device)
        # One coefficient per sample, broadcast over every non-batch axis of x0.
        # A fixed (-1, 1, 1, 1) view silently mis-broadcasts non-4-D inputs.
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        signal_scale = sqrt_alpha_bars[index].to(x0.device).view(coeff_shape)
        noise_scale = sqrt_one_minus_alpha_bars[index].to(x0.device).view(coeff_shape)

        return signal_scale * x0 + noise_scale * noise
ddpm/hyper_param.py ADDED
@@ -0,0 +1,180 @@
1
+ """Hyperparameters for Denoising Diffusion Probabilistic Models (DDPM) noise schedule.
2
+
3
+ This module implements a flexible noise schedule for DDPM, as described in Ho et al.
4
+ (2020, "Denoising Diffusion Probabilistic Models"). It supports multiple beta schedule
5
+ methods and allows for trainable or fixed noise schedules.
6
+ """
7
+
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+
14
+
15
class HyperParamsDDPM(nn.Module):
    """Hyperparameters for DDPM noise schedule with flexible beta computation.

    Builds the beta schedule and, for fixed schedules, the derived quantities
    (alphas, alpha_bars and their square roots) as buffers. For trainable
    schedules only `betas` is stored (as a parameter) and the derived
    quantities are obtained on demand from `compute_schedule`.

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Starting value for beta (default: 1e-4).
    beta_end : float, optional
        Ending value for beta (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        Method for computing the beta schedule (default: "linear").
        Supported methods: "linear", "sigmoid", "quadratic", "constant", "inverse_time".

    Attributes
    ----------
    num_steps : int
        Number of diffusion steps.
    beta_start : float
        Minimum beta value.
    beta_end : float
        Maximum beta value.
    trainable_beta : bool
        Whether the beta schedule is trainable.
    beta_method : str
        Method used for beta schedule computation.
    betas : torch.Tensor
        Beta schedule, shape (num_steps,). An `nn.Parameter` if
        `trainable_beta` is True, otherwise a buffer.
    alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars : torch.Tensor
        Derived buffers of shape (num_steps,). Registered only when
        `trainable_beta` is False.

    Raises
    ------
    ValueError
        If `beta_start`/`beta_end` do not satisfy 0 < beta_start < beta_end < 1,
        if `num_steps` is not positive, or if `beta_method` is unknown.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear"):
        super().__init__()
        # Validate before any tensor is built.
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Derived quantities are recomputed from the parameter when needed.
            self.betas = nn.Parameter(betas_init)
        else:
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', 1 - self.betas)
            self.register_buffer('alpha_bars', torch.cumprod(self.alphas, dim=0))
            self.register_buffer('sqrt_alpha_bars', torch.sqrt(self.alpha_bars))
            self.register_buffer('sqrt_one_minus_alpha_bars', torch.sqrt(1 - self.alpha_bars))

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        beta_range : tuple
            Tuple of (min_beta, max_beta) bounding the beta values.
        num_steps : int
            Number of diffusion steps.
        method : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Beta values clamped to `beta_range`, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported beta schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min**0.5, beta_max**0.5, num_steps)
            beta = x**2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            raw = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = raw.max() - raw.min()
            if span > 0:
                # Rescale to [beta_min, beta_max].
                beta = beta_min + (beta_max - beta_min) * (raw - raw.min()) / span
            else:
                # A single step has zero span; rescaling would divide by zero.
                # Use beta_min, which is what the linear schedule gives for one step.
                beta = torch.full_like(raw, beta_min)
        elif method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        return torch.clamp(beta, min=beta_min, max=beta_max)

    @staticmethod
    def compute_schedule(betas):
        """Computes noise schedule parameters dynamically from betas.

        Parameters
        ----------
        betas : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Returns
        -------
        tuple
            ``(betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars)``
            where ``alphas = 1 - betas`` and ``alpha_bars`` is the cumulative
            product of ``alphas``; every item has shape (num_steps,).
        """
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)

    def constrain_betas(self):
        """Clamps trainable betas in-place to [beta_start, beta_end].

        Notes
        -----
        Does nothing when `trainable_beta` is False.
        """
        if self.trainable_beta:
            # no_grad: the in-place clamp must not be recorded by autograd.
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)