TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ddim/train_ddim.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""Training module for Denoising Diffusion Implicit Models (DDIM).
|
|
2
|
+
|
|
3
|
+
This module implements the training process for DDIM, as described in Song et al. (2021,
|
|
4
|
+
"Denoising Diffusion Implicit Models"). It supports both unconditional and conditional
|
|
5
|
+
training with text prompts, using mixed precision and learning rate scheduling.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from torch.cuda.amp import GradScaler, autocast
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
from torch.amp import GradScaler, autocast
|
|
13
|
+
from torch.optim.lr_scheduler import LambdaLR
|
|
14
|
+
from transformers import BertTokenizer
|
|
15
|
+
import warnings
|
|
16
|
+
from forward_ddim import ForwardDDIM
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TrainDDIM(nn.Module):
|
|
22
|
+
"""Trainer for Denoising Diffusion Implicit Models (DDIM).
|
|
23
|
+
|
|
24
|
+
Manages the training process for DDIM, optimizing a noise predictor model to learn
|
|
25
|
+
the noise added by the forward diffusion process. Supports conditional training with
|
|
26
|
+
text prompts, mixed precision training, learning rate scheduling, early stopping, and
|
|
27
|
+
checkpointing, as inspired by Song et al. (2021).
|
|
28
|
+
|
|
29
|
+
Parameters
|
|
30
|
+
----------
|
|
31
|
+
noise_predictor : nn.Module
|
|
32
|
+
Model to predict noise added during the forward diffusion process.
|
|
33
|
+
hyper_params_model : nn.Module
|
|
34
|
+
Hyperparameter module (e.g., HyperParamsDDIM) defining the noise schedule.
|
|
35
|
+
data_loader : torch.utils.data.DataLoader
|
|
36
|
+
DataLoader for training data.
|
|
37
|
+
optimizer : torch.optim.Optimizer
|
|
38
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
39
|
+
objective : callable
|
|
40
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
41
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
42
|
+
DataLoader for validation data, default None.
|
|
43
|
+
max_epoch : int, optional
|
|
44
|
+
Maximum number of training epochs (default: 1000).
|
|
45
|
+
device : torch.device, optional
|
|
46
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
47
|
+
conditional_model : nn.Module, optional
|
|
48
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
49
|
+
tokenizer : BertTokenizer, optional
|
|
50
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
51
|
+
max_length : int, optional
|
|
52
|
+
Maximum length for tokenized prompts (default: 77).
|
|
53
|
+
store_path : str, optional
|
|
54
|
+
Path to save model checkpoints (default: "ddim_model.pth").
|
|
55
|
+
patience : int, optional
|
|
56
|
+
Number of epochs to wait for improvement before early stopping (default: 10).
|
|
57
|
+
warmup_epochs : int, optional
|
|
58
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
59
|
+
val_frequency : int, optional
|
|
60
|
+
Frequency (in epochs) for validation (default: 10).
|
|
61
|
+
|
|
62
|
+
Attributes
|
|
63
|
+
----------
|
|
64
|
+
device : torch.device
|
|
65
|
+
Device used for computation.
|
|
66
|
+
noise_predictor : nn.Module
|
|
67
|
+
Noise prediction model.
|
|
68
|
+
hyper_params_model : nn.Module
|
|
69
|
+
Hyperparameter module for the noise schedule.
|
|
70
|
+
conditional_model : nn.Module or None
|
|
71
|
+
Conditional model for text-based training, if provided.
|
|
72
|
+
optimizer : torch.optim.Optimizer
|
|
73
|
+
Optimizer for training.
|
|
74
|
+
objective : callable
|
|
75
|
+
Loss function for training.
|
|
76
|
+
store_path : str
|
|
77
|
+
Path for saving checkpoints.
|
|
78
|
+
data_loader : torch.utils.data.DataLoader
|
|
79
|
+
Training data loader.
|
|
80
|
+
val_loader : torch.utils.data.DataLoader or None
|
|
81
|
+
Validation data loader, if provided.
|
|
82
|
+
max_epoch : int
|
|
83
|
+
Maximum training epochs.
|
|
84
|
+
max_length : int
|
|
85
|
+
Maximum length for tokenized prompts.
|
|
86
|
+
patience : int
|
|
87
|
+
Patience for early stopping.
|
|
88
|
+
scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
89
|
+
Learning rate scheduler based on validation or training loss.
|
|
90
|
+
forward_diffusion : ForwardDDIM
|
|
91
|
+
Forward diffusion module for DDIM.
|
|
92
|
+
warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
|
|
93
|
+
Learning rate scheduler for warmup.
|
|
94
|
+
val_frequency : int
|
|
95
|
+
Frequency for validation.
|
|
96
|
+
tokenizer : BertTokenizer
|
|
97
|
+
Tokenizer for text prompts.
|
|
98
|
+
|
|
99
|
+
Raises
|
|
100
|
+
------
|
|
101
|
+
ValueError
|
|
102
|
+
If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
|
|
103
|
+
"""
|
|
104
|
+
def __init__(self, noise_predictor, hyper_params_model, data_loader, optimizer, objective, val_loader=None,
|
|
105
|
+
max_epoch=1000, device=None, conditional_model=None, tokenizer=None, max_length=77,
|
|
106
|
+
store_path=None, patience=10, warmup_epochs=100, val_frequency=10):
|
|
107
|
+
super().__init__()
|
|
108
|
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
109
|
+
self.noise_predictor = noise_predictor
|
|
110
|
+
self.hyper_params_model = hyper_params_model.to(self.device)
|
|
111
|
+
self.conditional_model = conditional_model
|
|
112
|
+
self.optimizer = optimizer
|
|
113
|
+
self.objective = objective
|
|
114
|
+
self.store_path = store_path or "ddim_model.pth"
|
|
115
|
+
self.data_loader = data_loader
|
|
116
|
+
self.val_loader = val_loader
|
|
117
|
+
self.max_epoch = max_epoch
|
|
118
|
+
self.max_length = max_length
|
|
119
|
+
self.patience = patience
|
|
120
|
+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
|
|
121
|
+
self.forward_diffusion = ForwardDDIM(hyper_params=self.hyper_params_model).to(self.device)
|
|
122
|
+
self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
123
|
+
self.val_frequency = val_frequency
|
|
124
|
+
if tokenizer is None:
|
|
125
|
+
try:
|
|
126
|
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
127
|
+
except Exception as e:
|
|
128
|
+
raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
|
|
129
|
+
|
|
130
|
+
def load_checkpoint(self, checkpoint_path):
|
|
131
|
+
"""Loads a training checkpoint to resume training.
|
|
132
|
+
|
|
133
|
+
Restores the state of the noise predictor, conditional model (if applicable),
|
|
134
|
+
and optimizer from a saved checkpoint.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
checkpoint_path : str
|
|
139
|
+
Path to the checkpoint file.
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
tuple
|
|
144
|
+
A tuple containing:
|
|
145
|
+
- epoch: The epoch at which the checkpoint was saved (int).
|
|
146
|
+
- loss: The loss at the checkpoint (float).
|
|
147
|
+
|
|
148
|
+
Raises
|
|
149
|
+
------
|
|
150
|
+
FileNotFoundError
|
|
151
|
+
If the checkpoint file is not found.
|
|
152
|
+
KeyError
|
|
153
|
+
If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
|
|
154
|
+
or 'optimizer_state_dict').
|
|
155
|
+
|
|
156
|
+
Warns
|
|
157
|
+
-----
|
|
158
|
+
warnings.warn
|
|
159
|
+
If the optimizer state cannot be loaded, if the checkpoint contains a
|
|
160
|
+
conditional model state but none is defined, or if no conditional model
|
|
161
|
+
state is provided when expected.
|
|
162
|
+
"""
|
|
163
|
+
try:
|
|
164
|
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|
165
|
+
except FileNotFoundError:
|
|
166
|
+
raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
|
|
167
|
+
|
|
168
|
+
if 'model_state_dict_noise_predictor' not in checkpoint:
|
|
169
|
+
raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
|
|
170
|
+
self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
|
|
171
|
+
|
|
172
|
+
if self.conditional_model is not None:
|
|
173
|
+
if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
174
|
+
self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
|
|
175
|
+
else:
|
|
176
|
+
warnings.warn(
|
|
177
|
+
"Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
|
|
178
|
+
elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
179
|
+
warnings.warn(
|
|
180
|
+
"Checkpoint contains conditional model state, but no conditional model is defined in this instance")
|
|
181
|
+
|
|
182
|
+
if 'optimizer_state_dict' not in checkpoint:
|
|
183
|
+
raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
|
|
184
|
+
try:
|
|
185
|
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
186
|
+
except ValueError as e:
|
|
187
|
+
warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
|
|
188
|
+
|
|
189
|
+
epoch = checkpoint.get('epoch', -1)
|
|
190
|
+
loss = checkpoint.get('loss', float('inf'))
|
|
191
|
+
|
|
192
|
+
self.noise_predictor.to(self.device)
|
|
193
|
+
if self.conditional_model is not None:
|
|
194
|
+
self.conditional_model.to(self.device)
|
|
195
|
+
|
|
196
|
+
print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
|
|
197
|
+
return epoch, loss
|
|
198
|
+
|
|
199
|
+
@staticmethod
|
|
200
|
+
def warmup_scheduler(optimizer, warmup_epochs=10):
|
|
201
|
+
"""Creates a learning rate scheduler for warmup.
|
|
202
|
+
|
|
203
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
204
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
205
|
+
|
|
206
|
+
Parameters
|
|
207
|
+
----------
|
|
208
|
+
optimizer : torch.optim.Optimizer
|
|
209
|
+
Optimizer to apply the scheduler to.
|
|
210
|
+
warmup_epochs : int, optional
|
|
211
|
+
Number of epochs for the warmup phase (default: 10).
|
|
212
|
+
|
|
213
|
+
Returns
|
|
214
|
+
-------
|
|
215
|
+
torch.optim.lr_scheduler.LambdaLR
|
|
216
|
+
Learning rate scheduler for warmup.
|
|
217
|
+
"""
|
|
218
|
+
def lr_lambda(epoch):
|
|
219
|
+
if epoch < warmup_epochs:
|
|
220
|
+
return epoch / warmup_epochs
|
|
221
|
+
return 1.0
|
|
222
|
+
|
|
223
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
224
|
+
|
|
225
|
+
def forward(self):
|
|
226
|
+
"""Trains the DDIM model to predict noise added by the forward diffusion process.
|
|
227
|
+
|
|
228
|
+
Executes the training loop, optimizing the noise predictor and conditional model
|
|
229
|
+
(if applicable) using mixed precision, gradient clipping, and learning rate
|
|
230
|
+
scheduling. Supports validation, early stopping, and checkpointing.
|
|
231
|
+
|
|
232
|
+
Returns
|
|
233
|
+
-------
|
|
234
|
+
tuple
|
|
235
|
+
A tuple containing:
|
|
236
|
+
- train_losses: List of mean training losses per epoch (list of float).
|
|
237
|
+
- best_val_loss: Best validation or training loss achieved (float).
|
|
238
|
+
|
|
239
|
+
Notes
|
|
240
|
+
-----
|
|
241
|
+
- Training uses mixed precision via `torch.cuda.amp` or `torch.amp` for efficiency.
|
|
242
|
+
- Checkpoints are saved when the validation (or training) loss improves, and on early stopping.
|
|
243
|
+
- Early stopping is triggered if no improvement occurs for `patience` epochs.
|
|
244
|
+
"""
|
|
245
|
+
self.noise_predictor.train()
|
|
246
|
+
self.noise_predictor.to(self.device)
|
|
247
|
+
if self.conditional_model is not None:
|
|
248
|
+
self.conditional_model.train()
|
|
249
|
+
self.conditional_model.to(self.device)
|
|
250
|
+
|
|
251
|
+
scaler = GradScaler()
|
|
252
|
+
train_losses = []
|
|
253
|
+
best_val_loss = float("inf")
|
|
254
|
+
wait = 0
|
|
255
|
+
for epoch in range(self.max_epoch):
|
|
256
|
+
train_losses_ = []
|
|
257
|
+
for x, y in tqdm(self.data_loader):
|
|
258
|
+
x = x.to(self.device)
|
|
259
|
+
|
|
260
|
+
if self.conditional_model is not None:
|
|
261
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
262
|
+
y_list = [str(item) for item in y_list]
|
|
263
|
+
y_encoded = self.tokenizer(
|
|
264
|
+
y_list,
|
|
265
|
+
padding="max_length",
|
|
266
|
+
truncation=True,
|
|
267
|
+
max_length=self.max_length,
|
|
268
|
+
return_tensors="pt"
|
|
269
|
+
).to(self.device)
|
|
270
|
+
input_ids = y_encoded["input_ids"]
|
|
271
|
+
attention_mask = y_encoded["attention_mask"]
|
|
272
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
273
|
+
else:
|
|
274
|
+
y_encoded = None
|
|
275
|
+
|
|
276
|
+
self.optimizer.zero_grad()
|
|
277
|
+
with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
|
|
278
|
+
noise = torch.randn_like(x).to(self.device)
|
|
279
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
280
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
281
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
282
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
283
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
284
|
+
loss = self.objective(p_noise, noise)
|
|
285
|
+
scaler.scale(loss).backward()
|
|
286
|
+
nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
|
|
287
|
+
if self.conditional_model is not None:
|
|
288
|
+
nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
|
|
289
|
+
scaler.step(self.optimizer)
|
|
290
|
+
scaler.update()
|
|
291
|
+
self.warmup_lr_scheduler.step()
|
|
292
|
+
train_losses_.append(loss.item())
|
|
293
|
+
|
|
294
|
+
mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
|
|
295
|
+
train_losses.append(mean_train_loss)
|
|
296
|
+
print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
|
|
297
|
+
|
|
298
|
+
if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
|
|
299
|
+
val_loss = self.validate()
|
|
300
|
+
print(f" | Val Loss: {val_loss:.4f}")
|
|
301
|
+
current_best = val_loss
|
|
302
|
+
self.scheduler.step(val_loss)
|
|
303
|
+
else:
|
|
304
|
+
print()
|
|
305
|
+
current_best = mean_train_loss
|
|
306
|
+
self.scheduler.step(mean_train_loss)
|
|
307
|
+
|
|
308
|
+
if current_best < best_val_loss:
|
|
309
|
+
best_val_loss = current_best
|
|
310
|
+
wait = 0
|
|
311
|
+
try:
|
|
312
|
+
torch.save({
|
|
313
|
+
'epoch': epoch + 1,
|
|
314
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
315
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
316
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
317
|
+
'loss': best_val_loss,
|
|
318
|
+
'hyper_params_model': self.hyper_params_model,
|
|
319
|
+
'max_epoch': self.max_epoch,
|
|
320
|
+
}, self.store_path)
|
|
321
|
+
print(f"Model saved at epoch {epoch + 1}")
|
|
322
|
+
except Exception as e:
|
|
323
|
+
print(f"Failed to save model: {e}")
|
|
324
|
+
else:
|
|
325
|
+
wait += 1
|
|
326
|
+
if wait >= self.patience:
|
|
327
|
+
print("Early stopping triggered")
|
|
328
|
+
try:
|
|
329
|
+
torch.save({
|
|
330
|
+
'epoch': epoch + 1,
|
|
331
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
332
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
333
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
334
|
+
'loss': best_val_loss,
|
|
335
|
+
'hyper_params_model': self.hyper_params_model,
|
|
336
|
+
'max_epoch': self.max_epoch,
|
|
337
|
+
}, self.store_path + "_early_stop.pth")
|
|
338
|
+
print(f"Final model saved at {self.store_path}_early_stop.pth")
|
|
339
|
+
except Exception as e:
|
|
340
|
+
print(f"Failed to save final model: {e}")
|
|
341
|
+
break
|
|
342
|
+
|
|
343
|
+
return train_losses, best_val_loss
|
|
344
|
+
|
|
345
|
+
def validate(self):
|
|
346
|
+
"""Validates the DDIM model on the validation dataset.
|
|
347
|
+
|
|
348
|
+
Computes the validation loss using the noise predictor and forward diffusion
|
|
349
|
+
process, with optional conditional inputs.
|
|
350
|
+
|
|
351
|
+
Returns
|
|
352
|
+
-------
|
|
353
|
+
float
|
|
354
|
+
Mean validation loss across the validation dataset.
|
|
355
|
+
"""
|
|
356
|
+
self.noise_predictor.eval()
|
|
357
|
+
if self.conditional_model is not None:
|
|
358
|
+
self.conditional_model.eval()
|
|
359
|
+
|
|
360
|
+
val_losses = []
|
|
361
|
+
with torch.no_grad():
|
|
362
|
+
for x, y in self.val_loader:
|
|
363
|
+
x = x.to(self.device)
|
|
364
|
+
|
|
365
|
+
if self.conditional_model is not None:
|
|
366
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
367
|
+
y_list = [str(item) for item in y_list]
|
|
368
|
+
y_encoded = self.tokenizer(
|
|
369
|
+
y_list,
|
|
370
|
+
padding="max_length",
|
|
371
|
+
truncation=True,
|
|
372
|
+
max_length=self.max_length,
|
|
373
|
+
return_tensors="pt"
|
|
374
|
+
).to(self.device)
|
|
375
|
+
input_ids = y_encoded["input_ids"]
|
|
376
|
+
attention_mask = y_encoded["attention_mask"]
|
|
377
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
378
|
+
else:
|
|
379
|
+
y_encoded = None
|
|
380
|
+
|
|
381
|
+
noise = torch.randn_like(x).to(self.device)
|
|
382
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
383
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
384
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
385
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
386
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
387
|
+
loss = self.objective(p_noise, noise)
|
|
388
|
+
val_losses.append(loss.item())
|
|
389
|
+
|
|
390
|
+
mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
|
|
391
|
+
self.noise_predictor.train()
|
|
392
|
+
if self.conditional_model is not None:
|
|
393
|
+
self.conditional_model.train()
|
|
394
|
+
return mean_val_loss
|
ddpm/__init__.py
ADDED
|
File without changes
|
ddpm/forward_ddpm.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Forward diffusion process for Denoising Diffusion Probabilistic Models (DDPM).
|
|
2
|
+
|
|
3
|
+
This module implements the forward diffusion process as described in the DDPM paper
|
|
4
|
+
(Ho et al., 2020, "Denoising Diffusion Probabilistic Models"). The forward process
|
|
5
|
+
gradually adds noise to the input data according to a predefined noise schedule.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ForwardDDPM(nn.Module):
    """Forward diffusion process of DDPM.

    Perturbs input data with Gaussian noise according to the closed-form DDPM
    forward process (Ho et al., 2020):
    ``xt = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

    Parameters
    ----------
    hyper_params : object
        Noise-schedule object. This class reads:
        - `num_steps`: number of diffusion steps.
        - `trainable_beta`: whether the schedule is recomputed from `betas` on each call.
        - `betas` and `compute_schedule(betas)`: used when `trainable_beta` is True;
          `compute_schedule` must return a 5-tuple whose last two items are the
          `sqrt_alpha_bars` and `sqrt_one_minus_alpha_bars` tables.
        - `sqrt_alpha_bars`, `sqrt_one_minus_alpha_bars`: precomputed tables, used
          when `trainable_beta` is False.

    Attributes
    ----------
    hyper_params : object
        The provided hyperparameter object.
    """

    def __init__(self, hyper_params):
        super().__init__()
        self.hyper_params = hyper_params

    def forward(self, x0, noise, time_steps):
        """Applies the forward diffusion process to the input data.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data with the batch in the first dimension, e.g.
            (batch_size, channels, height, width). Any number of trailing
            dimensions is supported.
        noise : torch.Tensor
            Gaussian noise tensor of the same shape as `x0`.
        time_steps : torch.Tensor
            Integer time step indices, shape (batch_size,), each in
            [0, hyper_params.num_steps - 1].

        Returns
        -------
        torch.Tensor
            Noisy data `xt` with the same shape as `x0`.

        Raises
        ------
        ValueError
            If any value in `time_steps` is outside [0, hyper_params.num_steps - 1].
        """
        num_steps = self.hyper_params.num_steps
        if not torch.all((time_steps >= 0) & (time_steps < num_steps)):
            raise ValueError(f"time_steps must be between 0 and {num_steps - 1}")

        if self.hyper_params.trainable_beta:
            # Recompute from the current betas so gradients can reach the schedule.
            _, _, _, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = self.hyper_params.compute_schedule(
                self.hyper_params.betas
            )
        else:
            sqrt_alpha_bars = self.hyper_params.sqrt_alpha_bars
            sqrt_one_minus_alpha_bars = self.hyper_params.sqrt_one_minus_alpha_bars

        # Index on the table's own device, then move the gathered values next to x0.
        index = time_steps.to(sqrt_alpha_bars.device)
        # One coefficient per sample, broadcast over every non-batch dimension of x0.
        broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        signal_scale = sqrt_alpha_bars[index].to(x0.device).view(broadcast_shape)
        noise_scale = sqrt_one_minus_alpha_bars[index].to(x0.device).view(broadcast_shape)

        return signal_scale * x0 + noise_scale * noise
|
ddpm/hyper_param.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Hyperparameters for Denoising Diffusion Probabilistic Models (DDPM) noise schedule.
|
|
2
|
+
|
|
3
|
+
This module implements a flexible noise schedule for DDPM, as described in Ho et al.
|
|
4
|
+
(2020, "Denoising Diffusion Probabilistic Models"). It supports multiple beta schedule
|
|
5
|
+
methods and allows for trainable or fixed noise schedules.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HyperParamsDDPM(nn.Module):
    """Hyperparameters for DDPM noise schedule with flexible beta computation.

    Builds the beta schedule with the selected method and either exposes it as a
    trainable parameter or registers it, together with the derived quantities, as
    fixed buffers (Ho et al., 2020).

    Parameters
    ----------
    num_steps : int, optional
        Number of diffusion steps (default: 1000).
    beta_start : float, optional
        Smallest beta value (default: 1e-4).
    beta_end : float, optional
        Largest beta value (default: 0.02).
    trainable_beta : bool, optional
        Whether the beta schedule is trainable (default: False).
    beta_method : str, optional
        One of "linear", "sigmoid", "quadratic", "constant", "inverse_time"
        (default: "linear").

    Attributes
    ----------
    num_steps, beta_start, beta_end, trainable_beta, beta_method
        The constructor arguments.
    betas : torch.Tensor
        Beta schedule, shape (num_steps,). An `nn.Parameter` if `trainable_beta`
        is True, otherwise a buffer.
    alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars : torch.Tensor
        Derived buffers, shape (num_steps,). Registered only when `trainable_beta`
        is False; for a trainable schedule use `compute_schedule(betas)`.

    Raises
    ------
    ValueError
        If not ``0 < beta_start < beta_end < 1``, if `num_steps` is not positive,
        or if `beta_method` is unknown.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, trainable_beta=False, beta_method="linear"):
        super().__init__()
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.trainable_beta = trainable_beta
        self.beta_method = beta_method

        # validate inputs
        if not (0 < beta_start < beta_end < 1):
            raise ValueError(f"beta_start ({beta_start}) and beta_end ({beta_end}) must satisfy 0 < start < end < 1")
        if num_steps <= 0:
            raise ValueError(f"num_steps ({num_steps}) must be positive")

        betas_init = self.compute_beta_schedule((beta_start, beta_end), num_steps, beta_method)

        if trainable_beta:
            # Derived quantities are not cached here: they would go stale as betas train.
            self.betas = nn.Parameter(betas_init)
        else:
            _, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars = self.compute_schedule(betas_init)
            self.register_buffer('betas', betas_init)
            self.register_buffer('alphas', alphas)
            self.register_buffer('alpha_bars', alpha_bars)
            self.register_buffer('sqrt_alpha_bars', sqrt_alpha_bars)
            self.register_buffer('sqrt_one_minus_alpha_bars', sqrt_one_minus_alpha_bars)

    def compute_beta_schedule(self, beta_range, num_steps, method):
        """Computes the beta schedule based on the specified method.

        Parameters
        ----------
        beta_range : tuple
            ``(min_beta, max_beta)``; the result is clamped to this range.
        num_steps : int
            Number of diffusion steps.
        method : str
            One of "linear", "sigmoid", "quadratic", "constant", "inverse_time".

        Returns
        -------
        torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Raises
        ------
        ValueError
            If `method` is not one of the supported beta schedule methods.
        """
        beta_min, beta_max = beta_range
        if method == "linear":
            beta = torch.linspace(beta_min, beta_max, num_steps)
        elif method == "sigmoid":
            x = torch.linspace(-6, 6, num_steps)
            beta = torch.sigmoid(x) * (beta_max - beta_min) + beta_min
        elif method == "quadratic":
            x = torch.linspace(beta_min**0.5, beta_max**0.5, num_steps)
            beta = x**2
        elif method == "constant":
            beta = torch.full((num_steps,), beta_max)
        elif method == "inverse_time":
            beta = 1.0 / torch.linspace(num_steps, 1, num_steps)
            span = beta.max() - beta.min()
            if span > 0:
                # scale to beta_range
                beta = beta_min + (beta_max - beta_min) * (beta - beta.min()) / span
            else:
                # A single step has zero span; rescaling would divide 0 by 0 and give NaN.
                beta = torch.full_like(beta, beta_min)
        else:
            raise ValueError(f"Unknown beta_method: {method}. Supported: linear, sigmoid, quadratic, constant, inverse_time")

        return torch.clamp(beta, min=beta_min, max=beta_max)

    @staticmethod
    def compute_schedule(betas):
        """Computes noise schedule parameters dynamically from betas.

        Parameters
        ----------
        betas : torch.Tensor
            Tensor of beta values, shape (num_steps,).

        Returns
        -------
        tuple
            ``(betas, alphas, alpha_bars, sqrt_alpha_bars, sqrt_one_minus_alpha_bars)``
            where ``alphas = 1 - betas`` and ``alpha_bars`` is the cumulative
            product of ``alphas``; every item has shape (num_steps,).
        """
        alphas = 1 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return betas, alphas, alpha_bars, torch.sqrt(alpha_bars), torch.sqrt(1 - alpha_bars)

    def constrain_betas(self):
        """Clamps trainable betas in-place to [beta_start, beta_end].

        Does nothing when `trainable_beta` is False.
        """
        if self.trainable_beta:
            with torch.no_grad():
                self.betas.clamp_(min=self.beta_start, max=self.beta_end)
|