TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
sde/train_sde.py ADDED
@@ -0,0 +1,400 @@
1
+ import torch
2
+ from torch.cuda.amp import GradScaler, autocast
3
+ import torch.nn as nn
4
+ from tqdm import tqdm
5
+ from torch.amp import GradScaler, autocast
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+ from transformers import BertTokenizer
8
+ from forward_sde import ForwardSDE
9
+ import warnings
10
+
11
+
12
+
13
+ class TrainSDE(nn.Module):
14
+ """Trainer for score-based generative models using Stochastic Differential Equations.
15
+
16
+ Manages the training process for SDE-based generative models, optimizing a noise
17
+ predictor to learn the noise added by the forward SDE process, as described in Song
18
+ et al. (2021). Supports conditional training with text prompts, mixed precision,
19
+ learning rate scheduling, early stopping, and checkpointing.
20
+
21
+ Parameters
22
+ ----------
23
+ method : str
24
+ SDE method to use for forward diffusion. Supported methods: "ve", "vp", "sub-vp", "ode".
25
+ noise_predictor : nn.Module
26
+ Model to predict noise added during the forward SDE process.
27
+ hyper_params_model : nn.Module
28
+ Hyperparameter module (e.g., HyperParamsSDE) defining the noise schedule and SDE parameters.
29
+ data_loader : torch.utils.data.DataLoader
30
+ DataLoader for training data.
31
+ optimizer : torch.optim.Optimizer
32
+ Optimizer for training the noise predictor and conditional model (if applicable).
33
+ objective : callable
34
+ Loss function to compute the difference between predicted and actual noise.
35
+ val_loader : torch.utils.data.DataLoader, optional
36
+ DataLoader for validation data, default None.
37
+ max_epoch : int, optional
38
+ Maximum number of training epochs (default: 1000).
39
+ device : torch.device, optional
40
+ Device for computation (default: CUDA if available, else CPU).
41
+ conditional_model : nn.Module, optional
42
+ Model for conditional generation (e.g., text embeddings), default None.
43
+ tokenizer : BertTokenizer, optional
44
+ Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
45
+ max_length : int, optional
46
+ Maximum length for tokenized prompts (default: 77).
47
+ store_path : str, optional
48
+ Path to save model checkpoints (default: "sde_model.pth").
49
+ patience : int, optional
50
+ Number of epochs to wait for improvement before early stopping (default: 10).
51
+ warmup_epochs : int, optional
52
+ Number of epochs for learning rate warmup (default: 100).
53
+ val_frequency : int, optional
54
+ Frequency (in epochs) for validation (default: 10).
55
+
56
+ Attributes
57
+ ----------
58
+ device : torch.device
59
+ Device used for computation.
60
+ method : str
61
+ Selected SDE method.
62
+ noise_predictor : nn.Module
63
+ Noise prediction model.
64
+ hyper_params_model : nn.Module
65
+ Hyperparameter module for the noise schedule and SDE parameters.
66
+ conditional_model : nn.Module or None
67
+ Conditional model for text-based training, if provided.
68
+ optimizer : torch.optim.Optimizer
69
+ Optimizer for training.
70
+ objective : callable
71
+ Loss function for training.
72
+ store_path : str
73
+ Path for saving checkpoints.
74
+ data_loader : torch.utils.data.DataLoader
75
+ Training data loader.
76
+ val_loader : torch.utils.data.DataLoader or None
77
+ Validation data loader, if provided.
78
+ max_epoch : int
79
+ Maximum training epochs.
80
+ max_length : int
81
+ Maximum length for tokenized prompts.
82
+ patience : int
83
+ Patience for early stopping.
84
+ scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
85
+ Learning rate scheduler based on validation or training loss.
86
+ forward_diffusion : ForwardSDE
87
+ Forward SDE diffusion module.
88
+ warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
89
+ Learning rate scheduler for warmup.
90
+ val_frequency : int
91
+ Frequency for validation.
92
+ tokenizer : BertTokenizer
93
+ Tokenizer for text prompts.
94
+
95
+ Raises
96
+ ------
97
+ ValueError
98
+ If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
99
+ """
100
+
101
+ def __init__(self, method, noise_predictor, hyper_params_model, data_loader, optimizer, objective, val_loader=None,
102
+ max_epoch=1000, device=None, conditional_model=None, tokenizer=None, max_length=77,
103
+ store_path=None, patience=10, warmup_epochs=100, val_frequency=10):
104
+ super().__init__()
105
+ self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
+ self.method = method
107
+ self.noise_predictor = noise_predictor
108
+ self.hyper_params_model = hyper_params_model.to(self.device)
109
+ self.conditional_model = conditional_model
110
+ self.optimizer = optimizer
111
+ self.objective = objective
112
+ self.store_path = store_path or "sde_model.pth"
113
+ self.data_loader = data_loader
114
+ self.val_loader = val_loader
115
+ self.max_epoch = max_epoch
116
+ self.max_length = max_length
117
+ self.patience = patience
118
+ self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
119
+ self.forward_diffusion = ForwardSDE(hyper_params=self.hyper_params_model, method=self.method).to(self.device)
120
+ self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
121
+ self.val_frequency = val_frequency
122
+ if tokenizer is None:
123
+ try:
124
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
125
+ except Exception as e:
126
+ raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
127
+
128
+
129
+ def load_checkpoint(self, checkpoint_path):
130
+ """Loads a training checkpoint to resume training.
131
+
132
+ Restores the state of the noise predictor, conditional model (if applicable),
133
+ and optimizer from a saved checkpoint.
134
+
135
+ Parameters
136
+ ----------
137
+ checkpoint_path : str
138
+ Path to the checkpoint file.
139
+
140
+ Returns
141
+ -------
142
+ tuple
143
+ A tuple containing:
144
+ - epoch: The epoch at which the checkpoint was saved (int).
145
+ - loss: The loss at the checkpoint (float).
146
+
147
+ Raises
148
+ ------
149
+ FileNotFoundError
150
+ If the checkpoint file is not found.
151
+ KeyError
152
+ If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
153
+ or 'optimizer_state_dict').
154
+
155
+ Warns
156
+ -----
157
+ warnings.warn
158
+ If the optimizer state cannot be loaded, if the checkpoint contains a
159
+ conditional model state but none is defined, or if no conditional model
160
+ state is provided when expected.
161
+ """
162
+ try:
163
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
164
+ except FileNotFoundError:
165
+ raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
166
+
167
+ if 'model_state_dict_noise_predictor' not in checkpoint:
168
+ raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
169
+ self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
170
+
171
+ if self.conditional_model is not None:
172
+ if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
173
+ self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
174
+ else:
175
+ warnings.warn(
176
+ "Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
177
+ elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
178
+ warnings.warn(
179
+ "Checkpoint contains conditional model state, but no conditional model is defined in this instance")
180
+
181
+ if 'optimizer_state_dict' not in checkpoint:
182
+ raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
183
+ try:
184
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
185
+ except ValueError as e:
186
+ warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
187
+
188
+ epoch = checkpoint.get('epoch', -1)
189
+ loss = checkpoint.get('loss', float('inf'))
190
+
191
+ self.noise_predictor.to(self.device)
192
+ if self.conditional_model is not None:
193
+ self.conditional_model.to(self.device)
194
+
195
+ print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
196
+ return epoch, loss
197
+
198
+ @staticmethod
199
+ def warmup_scheduler(optimizer, warmup_epochs=10):
200
+ """Creates a learning rate scheduler for warmup.
201
+
202
+ Generates a scheduler that linearly increases the learning rate from 0 to the
203
+ optimizer's initial value over the specified warmup epochs, then maintains it.
204
+
205
+ Parameters
206
+ ----------
207
+ optimizer : torch.optim.Optimizer
208
+ Optimizer to apply the scheduler to.
209
+ warmup_epochs : int, optional
210
+ Number of epochs for the warmup phase (default: 10).
211
+
212
+ Returns
213
+ -------
214
+ torch.optim.lr_scheduler.LambdaLR
215
+ Learning rate scheduler for warmup.
216
+ """
217
+ def lr_lambda(epoch):
218
+ if epoch < warmup_epochs:
219
+ return epoch / warmup_epochs
220
+ return 1.0
221
+
222
+ return LambdaLR(optimizer, lr_lambda)
223
+
224
+ def forward(self):
225
+ """Trains the SDE model to predict noise added by the forward diffusion process.
226
+
227
+ Executes the training loop, optimizing the noise predictor and conditional model
228
+ (if applicable) using mixed precision, gradient clipping, and learning rate
229
+ scheduling. Supports validation, early stopping, and checkpointing.
230
+
231
+ Returns
232
+ -------
233
+ tuple
234
+ A tuple containing:
235
+ - train_losses: List of mean training losses per epoch (list of float).
236
+ - best_val_loss: Best validation or training loss achieved (float).
237
+
238
+ Notes
239
+ -----
240
+ - Training uses mixed precision via `torch.cuda.amp` or `torch.amp` for efficiency.
241
+ - Checkpoints are saved when the validation (or training) loss improves, and on
242
+ early stopping.
243
+ - Early stopping is triggered if no improvement occurs for `patience` epochs.
244
+ """
245
+ self.noise_predictor.train()
246
+ self.noise_predictor.to(self.device)
247
+ if self.conditional_model is not None:
248
+ self.conditional_model.train()
249
+ self.conditional_model.to(self.device)
250
+
251
+ scaler = GradScaler()
252
+ train_losses = []
253
+ best_val_loss = float("inf")
254
+ wait = 0
255
+ for epoch in range(self.max_epoch):
256
+ train_losses_ = []
257
+ for x, y in tqdm(self.data_loader):
258
+ x = x.to(self.device)
259
+
260
+ if self.conditional_model is not None:
261
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
262
+ y_list = [str(item) for item in y_list]
263
+ y_encoded = self.tokenizer(
264
+ y_list,
265
+ padding="max_length",
266
+ truncation=True,
267
+ max_length=self.max_length,
268
+ return_tensors="pt"
269
+ ).to(self.device)
270
+ input_ids = y_encoded["input_ids"]
271
+ attention_mask = y_encoded["attention_mask"]
272
+ y_encoded = self.conditional_model(input_ids, attention_mask)
273
+ else:
274
+ y_encoded = None
275
+
276
+ self.optimizer.zero_grad()
277
+ with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
278
+ noise = torch.randn_like(x).to(self.device)
279
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
280
+ assert x.device == noise.device == t.device, "Device mismatch detected"
281
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
282
+ noisy_x = self.forward_diffusion(x, noise, t)
283
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
284
+ loss = self.objective(p_noise, noise)
285
+ scaler.scale(loss).backward()
286
+ nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
287
+ if self.conditional_model is not None:
288
+ nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
289
+ scaler.step(self.optimizer)
290
+ scaler.update()
291
+ self.warmup_lr_scheduler.step()
292
+ train_losses_.append(loss.item())
293
+
294
+ mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
295
+ train_losses.append(mean_train_loss)
296
+ print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
297
+
298
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
299
+ val_loss = self.validate()
300
+ print(f" | Val Loss: {val_loss:.4f}")
301
+ current_best = val_loss
302
+ self.scheduler.step(val_loss)
303
+ else:
304
+ print()
305
+ current_best = mean_train_loss
306
+ self.scheduler.step(mean_train_loss)
307
+
308
+ if current_best < best_val_loss:
309
+ best_val_loss = current_best
310
+ wait = 0
311
+ try:
312
+ torch.save({
313
+ 'epoch': epoch+1,
314
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
315
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
316
+ 'optimizer_state_dict': self.optimizer.state_dict(),
317
+ 'loss': best_val_loss,
318
+ 'hyper_params_model': self.hyper_params_model,
319
+ 'max_epoch': self.max_epoch,
320
+ }, self.store_path)
321
+ print(f"Model saved at epoch {epoch + 1}")
322
+ except Exception as e:
323
+ print(f"Failed to save model: {e}")
324
+ else:
325
+ wait += 1
326
+ if wait >= self.patience:
327
+ print("Early stopping triggered")
328
+ try:
329
+ torch.save({
330
+ 'epoch': epoch+1,
331
+ 'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
332
+ 'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
333
+ 'optimizer_state_dict': self.optimizer.state_dict(),
334
+ 'loss': best_val_loss,
335
+ 'hyper_params_model': self.hyper_params_model,
336
+ 'max_epoch': self.max_epoch,
337
+ }, self.store_path + "_early_stop.pth")
338
+ print(f"Final model saved at {self.store_path}_early_stop.pth")
339
+ except Exception as e:
340
+ print(f"Failed to save final model: {e}")
341
+ break
342
+
343
+ return train_losses, best_val_loss
344
+
345
+ def validate(self):
346
+ """Validates the SDE model on the validation dataset.
347
+
348
+ Computes the validation loss using the noise predictor and forward SDE diffusion
349
+ process, with optional conditional inputs.
350
+
351
+ Returns
352
+ -------
353
+ float
354
+ Mean validation loss across the validation dataset.
355
+
356
+ Notes
357
+ -----
358
+ - Validation is performed with `torch.no_grad()` for efficiency.
359
+ - The noise predictor and conditional model (if applicable) are set to evaluation
360
+ mode during validation and restored to training mode afterward.
361
+ """
362
+ self.noise_predictor.eval()
363
+ if self.conditional_model is not None:
364
+ self.conditional_model.eval()
365
+
366
+ val_losses = []
367
+ with torch.no_grad():
368
+ for x, y in self.val_loader:
369
+ x = x.to(self.device)
370
+
371
+ if self.conditional_model is not None:
372
+ y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
373
+ y_list = [str(item) for item in y_list]
374
+ y_encoded = self.tokenizer(
375
+ y_list,
376
+ padding="max_length",
377
+ truncation=True,
378
+ max_length=self.max_length,
379
+ return_tensors="pt"
380
+ ).to(self.device)
381
+ input_ids = y_encoded["input_ids"]
382
+ attention_mask = y_encoded["attention_mask"]
383
+ y_encoded = self.conditional_model(input_ids, attention_mask)
384
+ else:
385
+ y_encoded = None
386
+
387
+ noise = torch.randn_like(x).to(self.device)
388
+ t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
389
+ assert x.device == noise.device == t.device, "Device mismatch detected"
390
+ assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
391
+ noisy_x = self.forward_diffusion(x, noise, t)
392
+ p_noise = self.noise_predictor(noisy_x, t, y_encoded)
393
+ loss = self.objective(p_noise, noise)
394
+ val_losses.append(loss.item())
395
+
396
+ mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
397
+ self.noise_predictor.train()
398
+ if self.conditional_model is not None:
399
+ self.conditional_model.train()
400
+ return mean_val_loss
torchdiff/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ __version__ = "2.0.0"
2
+
3
+ from .ddim import ForwardDDIM, ReverseDDIM, VarianceSchedulerDDIM, TrainDDIM, SampleDDIM
4
+ from .ddpm import ForwardDDPM, ReverseDDPM, VarianceSchedulerDDPM, TrainDDPM, SampleDDPM
5
+ from .ldm import TrainLDM, TrainAE, AutoencoderLDM, SampleLDM
6
+ from .sde import ForwardSDE, ReverseSDE, VarianceSchedulerSDE, TrainSDE, SampleSDE
7
+ from .unclip import ForwardUnCLIP, ReverseUnCLIP, VarianceSchedulerUnCLIP, CLIPEncoder, SampleUnCLIP, UnClipDecoder, UnCLIPTransformerPrior, CLIPContextProjection, CLIPEmbeddingProjection, TrainUnClipDecoder, SampleUnCLIP, UpsamplerUnCLIP, TrainUpsamplerUnCLIP
8
+ from .utils import NoisePredictor, TextEncoder, Metrics