TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
ddpm/train_ddpm.py ADDED
@@ -0,0 +1,386 @@
1
+ import torch
2
+ from torch.cuda.amp import GradScaler, autocast
3
+ import torch.nn as nn
4
+ from tqdm import tqdm
5
+ from torch.amp import GradScaler, autocast
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ from transformers import BertTokenizer
8
+ from forward_ddpm import ForwardDDPM
9
+ import warnings
10
+
11
+
12
+
13
+ class TrainDDPM(nn.Module):
14
+ """Trainer for Denoising Diffusion Probabilistic Models (DDPM).
15
+
16
+ Manages the training process for DDPM, optimizing a noise predictor model to learn
17
+ the noise added by the forward diffusion process. Supports conditional training with
18
+ text prompts, mixed precision training, learning rate scheduling, early stopping, and
19
+ checkpointing, as inspired by Ho et al. (2020).
20
+
21
+ Parameters
22
+ ----------
23
+ noise_predictor : nn.Module
24
+ Model to predict noise added during the forward diffusion process.
25
+ hyper_params_model : nn.Module
26
+ Hyperparameter module (e.g., HyperParamsDDPM) defining the noise schedule.
27
+ data_loader : torch.utils.data.DataLoader
28
+ DataLoader for training data.
29
+ optimizer : torch.optim.Optimizer
30
+ Optimizer for training the noise predictor and conditional model (if applicable).
31
+ objective : callable
32
+ Loss function to compute the difference between predicted and actual noise.
33
+ val_loader : torch.utils.data.DataLoader, optional
34
+ DataLoader for validation data, default None.
35
+ max_epoch : int, optional
36
+ Maximum number of training epochs (default: 1000).
37
+ device : torch.device, optional
38
+ Device for computation (default: CUDA if available, else CPU).
39
+ conditional_model : nn.Module, optional
40
+ Model for conditional generation (e.g., text embeddings), default None.
41
+ tokenizer : BertTokenizer, optional
42
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
43
+ max_length : int, optional
44
+ Maximum length for tokenized prompts (default: 77).
45
+ store_path : str, optional
46
+ Path to save model checkpoints (default: "ddpm_model.pth").
47
+ patience : int, optional
48
+ Number of epochs to wait for improvement before early stopping (default: 10).
49
+ warmup_epochs : int, optional
50
+ Number of epochs for learning rate warmup (default: 100).
51
+ val_frequency : int, optional
52
+ Frequency (in epochs) for validation (default: 10).
53
+
54
+ Attributes
55
+ ----------
56
+ device : torch.device
57
+ Device used for computation.
58
+ noise_predictor : nn.Module
59
+ Noise prediction model.
60
+ hyper_params_model : nn.Module
61
+ Hyperparameter module for the noise schedule.
62
+ conditional_model : nn.Module or None
63
+ Conditional model for text-based training, if provided.
64
+ optimizer : torch.optim.Optimizer
65
+ Optimizer for training.
66
+ objective : callable
67
+ Loss function for training.
68
+ store_path : str
69
+ Path for saving checkpoints.
70
+ data_loader : torch.utils.data.DataLoader
71
+ Training data loader.
72
+ val_loader : torch.utils.data.DataLoader or None
73
+ Validation data loader, if provided.
74
+ max_epoch : int
75
+ Maximum training epochs.
76
+ max_length : int
77
+ Maximum length for tokenized prompts.
78
+ patience : int
79
+ Patience for early stopping.
80
+ scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
81
+ Learning rate scheduler based on validation or training loss.
82
+ forward_diffusion : ForwardDDPM
83
+ Forward diffusion module.
84
+ warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
85
+ Learning rate scheduler for warmup.
86
+ val_frequency : int
87
+ Frequency for validation.
88
+ tokenizer : BertTokenizer
89
+ Tokenizer for text prompts.
90
+
91
+ Raises
92
+ ------
93
+ ValueError
94
+ If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
95
+ """
96
+ def __init__(self, noise_predictor, hyper_params_model, data_loader, optimizer, objective, val_loader=None,
97
+ max_epoch=1000, device=None, conditional_model=None, tokenizer=None, max_length=77,
98
+ store_path=None, patience=10, warmup_epochs=100, val_frequency=10):
99
+ super().__init__()
100
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+ self.noise_predictor = noise_predictor
102
+ self.hyper_params_model = hyper_params_model.to(self.device)
103
+ self.conditional_model = conditional_model
104
+ self.optimizer = optimizer
105
+ self.objective = objective
106
+ self.store_path = store_path or "ddpm_model.pth"
107
+ self.data_loader = data_loader
108
+ self.val_loader = val_loader
109
+ self.max_epoch = max_epoch
110
+ self.max_length = max_length
111
+ self.patience = patience
112
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
113
+ self.forward_diffusion = ForwardDDPM(hyper_params=self.hyper_params_model).to(self.device)
114
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
115
+ self.val_frequency = val_frequency
116
+ if tokenizer is None:
117
+ try:
118
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
119
+ except Exception as e:
120
+ raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
121
+
122
+ def load_checkpoint(self, checkpoint_path):
123
+ """Loads a training checkpoint to resume training.
124
+
125
+ Restores the state of the noise predictor, conditional model (if applicable),
126
+ and optimizer from a saved checkpoint.
127
+
128
+ Parameters
129
+ ----------
130
+ checkpoint_path : str
131
+ Path to the checkpoint file.
132
+
133
+ Returns
134
+ -------
135
+ tuple
136
+ A tuple containing:
137
+ - epoch: The epoch at which the checkpoint was saved (int).
138
+ - loss: The loss at the checkpoint (float).
139
+
140
+ Raises
141
+ ------
142
+ FileNotFoundError
143
+ If the checkpoint file is not found.
144
+ KeyError
145
+ If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
146
+ or 'optimizer_state_dict').
147
+
148
+ Warns
149
+ -----
150
+ warnings.warn
151
+ If the optimizer state cannot be loaded, if the checkpoint contains a
152
+ conditional model state but none is defined, or if no conditional model
153
+ state is provided when expected.
154
+ """
155
+ try:
156
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
157
+ except FileNotFoundError:
158
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
159
+
160
+ if 'model_state_dict_noise_predictor' not in checkpoint:
161
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
162
+ self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
163
+
164
+ if self.conditional_model is not None:
165
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
166
+ self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
167
+ else:
168
+ warnings.warn(
169
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
170
+ elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
171
+ warnings.warn(
172
+ "Checkpoint contains conditional model state, but no conditional model is defined in this instance")
173
+
174
+ if 'optimizer_state_dict' not in checkpoint:
175
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
176
+ try:
177
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
178
+ except ValueError as e:
179
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
180
+
181
+ epoch = checkpoint.get('epoch', -1)
182
+ loss = checkpoint.get('loss', float('inf'))
183
+
184
+ self.noise_predictor.to(self.device)
185
+ if self.conditional_model is not None:
186
+ self.conditional_model.to(self.device)
187
+
188
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
189
+ return epoch, loss
190
+
191
+ @staticmethod
192
+ def warmup_scheduler(optimizer, warmup_epochs=10):
193
+ """Creates a learning rate scheduler for warmup.
194
+
195
+ Generates a scheduler that linearly increases the learning rate from 0 to the
196
+ optimizer's initial value over the specified warmup epochs, then maintains it.
197
+
198
+ Parameters
199
+ ----------
200
+ optimizer : torch.optim.Optimizer
201
+ Optimizer to apply the scheduler to.
202
+ warmup_epochs : int, optional
203
+ Number of epochs for the warmup phase (default: 10).
204
+
205
+ Returns
206
+ -------
207
+ torch.optim.lr_scheduler.LambdaLR
208
+ Learning rate scheduler for warmup.
209
+ """
210
+ def lr_lambda(epoch):
211
+ if epoch < warmup_epochs:
212
+ return epoch / warmup_epochs
213
+ return 1.0
214
+
215
+ return LambdaLR(optimizer, lr_lambda)
216
+
217
+ def forward(self):
218
+ """Trains the DDPM model to predict noise added by the forward diffusion process.
219
+
220
+ Executes the training loop, optimizing the noise predictor and conditional model
221
+ (if applicable) using mixed precision, gradient clipping, and learning rate
222
+ scheduling. Supports validation, early stopping, and checkpointing.
223
+
224
+ Returns
225
+ -------
226
+ tuple
227
+ A tuple containing:
228
+ - train_losses: List of mean training losses per epoch (list of float).
229
+ - best_val_loss: Best validation or training loss achieved (float).
230
+
231
+ Notes
232
+ -----
233
+ - Training uses mixed precision via `torch.cuda.amp` or `torch.amp` for efficiency.
234
+ - Checkpoints are saved when the validation (or training) loss improves, and on early stopping.
235
+ - Early stopping is triggered if no improvement occurs for `patience` epochs.
236
+ """
237
+ self.noise_predictor.train()
238
+ self.noise_predictor.to(self.device)
239
+ if self.conditional_model is not None:
240
+ self.conditional_model.train()
241
+ self.conditional_model.to(self.device)
242
+
243
+ scaler = GradScaler()
244
+ train_losses = []
245
+ best_val_loss = float("inf")
246
+ wait = 0
247
+ for epoch in range(self.max_epoch):
248
+ train_losses_ = []
249
+ for x, y in tqdm(self.data_loader):
250
+ x = x.to(self.device)
251
+
252
+ if self.conditional_model is not None:
253
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
254
+ y_list = [str(item) for item in y_list]
255
+ y_encoded = self.tokenizer(
256
+ y_list,
257
+ padding="max_length",
258
+ truncation=True,
259
+ max_length=self.max_length,
260
+ return_tensors="pt"
261
+ ).to(self.device)
262
+ input_ids = y_encoded["input_ids"]
263
+ attention_mask = y_encoded["attention_mask"]
264
+ y_encoded = self.conditional_model(input_ids, attention_mask)
265
+ else:
266
+ y_encoded = None
267
+
268
+ self.optimizer.zero_grad()
269
+ with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
270
+ noise = torch.randn_like(x).to(self.device)
271
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
272
+ assert x.device == noise.device == t.device, "Device mismatch detected"
273
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
274
+ noisy_x = self.forward_diffusion(x, noise, t)
275
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
276
+ loss = self.objective(p_noise, noise)
277
+ scaler.scale(loss).backward()
278
+ nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
279
+ if self.conditional_model is not None:
280
+ nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
281
+ scaler.step(self.optimizer)
282
+ scaler.update()
283
+ self.warmup_lr_scheduler.step()
284
+ train_losses_.append(loss.item())
285
+
286
+ mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
287
+ train_losses.append(mean_train_loss)
288
+ print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
289
+
290
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
291
+ val_loss = self.validate()
292
+ print(f" | Val Loss: {val_loss:.4f}")
293
+ current_best = val_loss
294
+ self.scheduler.step(val_loss)
295
+ else:
296
+ print()
297
+ current_best = mean_train_loss
298
+ self.scheduler.step(mean_train_loss)
299
+
300
+ if current_best < best_val_loss:
301
+ best_val_loss = current_best
302
+ wait = 0
303
+ try:
304
+ torch.save({
305
+ 'epoch': epoch+1,
306
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
307
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
308
+ 'optimizer_state_dict': self.optimizer.state_dict(),
309
+ 'loss': best_val_loss,
310
+ 'hyper_params_model': self.hyper_params_model,
311
+ 'max_epoch': self.max_epoch,
312
+ }, self.store_path)
313
+ print(f"Model saved at epoch {epoch + 1}")
314
+ except Exception as e:
315
+ print(f"Failed to save model: {e}")
316
+ else:
317
+ wait += 1
318
+ if wait >= self.patience:
319
+ print("Early stopping triggered")
320
+ try:
321
+ torch.save({
322
+ 'epoch': epoch+1,
323
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
324
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
325
+ 'optimizer_state_dict': self.optimizer.state_dict(),
326
+ 'loss': best_val_loss,
327
+ 'hyper_params_model': self.hyper_params_model,
328
+ 'max_epoch': self.max_epoch,
329
+ }, self.store_path + "_early_stop.pth")
330
+ print(f"Final model saved at {self.store_path}_early_stop.pth")
331
+ except Exception as e:
332
+ print(f"Failed to save final model: {e}")
333
+ break
334
+
335
+ return train_losses, best_val_loss
336
+
337
+ def validate(self):
338
+ """Validates the DDPM model on the validation dataset.
339
+
340
+ Computes the validation loss using the noise predictor and forward diffusion
341
+ process, with optional conditional inputs.
342
+
343
+ Returns
344
+ -------
345
+ float
346
+ Mean validation loss across the validation dataset.
347
+ """
348
+ self.noise_predictor.eval()
349
+ if self.conditional_model is not None:
350
+ self.conditional_model.eval()
351
+
352
+ val_losses = []
353
+ with torch.no_grad():
354
+ for x, y in self.val_loader:
355
+ x = x.to(self.device)
356
+
357
+ if self.conditional_model is not None:
358
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
359
+ y_list = [str(item) for item in y_list]
360
+ y_encoded = self.tokenizer(
361
+ y_list,
362
+ padding="max_length",
363
+ truncation=True,
364
+ max_length=self.max_length,
365
+ return_tensors="pt"
366
+ ).to(self.device)
367
+ input_ids = y_encoded["input_ids"]
368
+ attention_mask = y_encoded["attention_mask"]
369
+ y_encoded = self.conditional_model(input_ids, attention_mask)
370
+ else:
371
+ y_encoded = None
372
+
373
+ noise = torch.randn_like(x).to(self.device)
374
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
375
+ assert x.device == noise.device == t.device, "Device mismatch detected"
376
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
377
+ noisy_x = self.forward_diffusion(x, noise, t)
378
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
379
+ loss = self.objective(p_noise, noise)
380
+ val_losses.append(loss.item())
381
+
382
+ mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
383
+ self.noise_predictor.train()
384
+ if self.conditional_model is not None:
385
+ self.conditional_model.train()
386
+ return mean_val_loss
ldm/__init__.py ADDED
File without changes