TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
ldm/train_ldm.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from tqdm import tqdm
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from torch.cuda.amp import GradScaler, autocast
|
|
5
|
+
from torch.amp import GradScaler, autocast
|
|
6
|
+
from torch.optim.lr_scheduler import LambdaLR
|
|
7
|
+
from transformers import BertTokenizer
|
|
8
|
+
import warnings
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TrainLDM(nn.Module):
|
|
14
|
+
"""Trainer for Latent Diffusion Models (LDM).
|
|
15
|
+
|
|
16
|
+
Manages the training process for LDMs, optimizing a noise predictor to learn the noise
|
|
17
|
+
added by the forward diffusion process in the latent space, as described in Rombach
|
|
18
|
+
et al. (2022). Uses a pre-trained compressor model to encode images into a latent
|
|
19
|
+
space, supports conditional training with text prompts, mixed precision, learning rate
|
|
20
|
+
scheduling, early stopping, and checkpointing.
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
forward_model : nn.Module
|
|
25
|
+
Forward diffusion module (e.g., ForwardDDPM, ForwardSDE) to add noise in the latent space.
|
|
26
|
+
hyper_params_model : nn.Module
|
|
27
|
+
Hyperparameter module (e.g., HyperParamsDDPM, HyperParamsSDE) defining the noise schedule.
|
|
28
|
+
noise_predictor : nn.Module
|
|
29
|
+
Model to predict noise added during the forward diffusion process.
|
|
30
|
+
compressor_model : nn.Module
|
|
31
|
+
Pre-trained model to encode images into the latent space and decode back (e.g., autoencoder).
|
|
32
|
+
optimizer : torch.optim.Optimizer
|
|
33
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
34
|
+
objective : callable
|
|
35
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
36
|
+
data_loader : torch.utils.data.DataLoader
|
|
37
|
+
DataLoader for training data.
|
|
38
|
+
conditional_model : nn.Module, optional
|
|
39
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
40
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
41
|
+
DataLoader for validation data, default None.
|
|
42
|
+
max_epoch : int, optional
|
|
43
|
+
Maximum number of training epochs (default: 1000).
|
|
44
|
+
device : torch.device, optional
|
|
45
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
46
|
+
store_path : str, optional
|
|
47
|
+
Path to save model checkpoints (default: "ldm_model.pth").
|
|
48
|
+
patience : int, optional
|
|
49
|
+
Number of epochs to wait for improvement before early stopping (default: 10).
|
|
50
|
+
warmup_epochs : int, optional
|
|
51
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
52
|
+
tokenizer : BertTokenizer, optional
|
|
53
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
54
|
+
max_length : int, optional
|
|
55
|
+
Maximum length for tokenized prompts (default: 77).
|
|
56
|
+
val_frequency : int, optional
|
|
57
|
+
Frequency (in epochs) for validation (default: 10).
|
|
58
|
+
|
|
59
|
+
Attributes
|
|
60
|
+
----------
|
|
61
|
+
device : torch.device
|
|
62
|
+
Device used for computation.
|
|
63
|
+
forward_diffusion : nn.Module
|
|
64
|
+
Forward diffusion module.
|
|
65
|
+
hyper_params_model : nn.Module
|
|
66
|
+
Hyperparameter module for the noise schedule.
|
|
67
|
+
noise_predictor : nn.Module
|
|
68
|
+
Noise prediction model.
|
|
69
|
+
compressor_model : nn.Module
|
|
70
|
+
Compressor model for latent space encoding/decoding.
|
|
71
|
+
conditional_model : nn.Module or None
|
|
72
|
+
Conditional model for text-based training, if provided.
|
|
73
|
+
optimizer : torch.optim.Optimizer
|
|
74
|
+
Optimizer for training.
|
|
75
|
+
objective : callable
|
|
76
|
+
Loss function for training.
|
|
77
|
+
data_loader : torch.utils.data.DataLoader
|
|
78
|
+
Training data loader.
|
|
79
|
+
val_loader : torch.utils.data.DataLoader or None
|
|
80
|
+
Validation data loader, if provided.
|
|
81
|
+
max_epoch : int
|
|
82
|
+
Maximum training epochs.
|
|
83
|
+
store_path : str
|
|
84
|
+
Path for saving checkpoints.
|
|
85
|
+
max_length : int
|
|
86
|
+
Maximum length for tokenized prompts.
|
|
87
|
+
patience : int
|
|
88
|
+
Patience for early stopping.
|
|
89
|
+
scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
90
|
+
Learning rate scheduler based on validation or training loss.
|
|
91
|
+
warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
|
|
92
|
+
Learning rate scheduler for warmup.
|
|
93
|
+
tokenizer : BertTokenizer
|
|
94
|
+
Tokenizer for text prompts.
|
|
95
|
+
val_frequency : int
|
|
96
|
+
Frequency for validation.
|
|
97
|
+
|
|
98
|
+
Raises
|
|
99
|
+
------
|
|
100
|
+
ValueError
|
|
101
|
+
If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
|
|
102
|
+
"""
|
|
103
|
+
def __init__(self, forward_model, hyper_params_model, noise_predictor, compressor_model, optimizer, objective, data_loader,
|
|
104
|
+
conditional_model=None, val_loader=None, max_epoch=1000, device=None, store_path=None,
|
|
105
|
+
patience=10, warmup_epochs=100, tokenizer=None, max_length=77, val_frequency=10):
|
|
106
|
+
super().__init__()
|
|
107
|
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
108
|
+
self.forward_diffusion = forward_model.to(device)
|
|
109
|
+
self.hyper_params_model = hyper_params_model.to(device)
|
|
110
|
+
self.noise_predictor = noise_predictor
|
|
111
|
+
self.compressor_model = compressor_model
|
|
112
|
+
self.conditional_model = conditional_model
|
|
113
|
+
self.optimizer = optimizer
|
|
114
|
+
self.objective = objective
|
|
115
|
+
self.data_loader = data_loader
|
|
116
|
+
self.val_loader = val_loader
|
|
117
|
+
self.max_epoch = max_epoch
|
|
118
|
+
self.store_path = store_path or "ldm_model.pth"
|
|
119
|
+
self.max_length = max_length
|
|
120
|
+
self.patience = patience
|
|
121
|
+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
|
|
122
|
+
self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
123
|
+
self.val_frequency = val_frequency
|
|
124
|
+
if tokenizer is None:
|
|
125
|
+
try:
|
|
126
|
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
127
|
+
except Exception as e:
|
|
128
|
+
raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
|
|
129
|
+
|
|
130
|
+
def load_checkpoint(self, checkpoint_path):
|
|
131
|
+
"""Loads a training checkpoint to resume training.
|
|
132
|
+
|
|
133
|
+
Restores the state of the noise predictor, conditional model (if applicable),
|
|
134
|
+
and optimizer from a saved checkpoint.
|
|
135
|
+
|
|
136
|
+
Parameters
|
|
137
|
+
----------
|
|
138
|
+
checkpoint_path : str
|
|
139
|
+
Path to the checkpoint file.
|
|
140
|
+
|
|
141
|
+
Returns
|
|
142
|
+
-------
|
|
143
|
+
tuple
|
|
144
|
+
A tuple containing:
|
|
145
|
+
- epoch: The epoch at which the checkpoint was saved (int).
|
|
146
|
+
- loss: The loss at the checkpoint (float).
|
|
147
|
+
|
|
148
|
+
Raises
|
|
149
|
+
------
|
|
150
|
+
FileNotFoundError
|
|
151
|
+
If the checkpoint file is not found.
|
|
152
|
+
KeyError
|
|
153
|
+
If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
|
|
154
|
+
or 'optimizer_state_dict').
|
|
155
|
+
|
|
156
|
+
Warns
|
|
157
|
+
-----
|
|
158
|
+
warnings.warn
|
|
159
|
+
If the optimizer state cannot be loaded, if the checkpoint contains a
|
|
160
|
+
conditional model state but none is defined, or if no conditional model
|
|
161
|
+
state is provided when expected.
|
|
162
|
+
"""
|
|
163
|
+
try:
|
|
164
|
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|
165
|
+
except FileNotFoundError:
|
|
166
|
+
raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
|
|
167
|
+
|
|
168
|
+
if 'model_state_dict_noise_predictor' not in checkpoint:
|
|
169
|
+
raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
|
|
170
|
+
self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
|
|
171
|
+
|
|
172
|
+
if self.conditional_model is not None:
|
|
173
|
+
if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
174
|
+
self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
|
|
175
|
+
else:
|
|
176
|
+
warnings.warn(
|
|
177
|
+
"Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
|
|
178
|
+
elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
179
|
+
warnings.warn(
|
|
180
|
+
"Checkpoint contains conditional model state, but no conditional model is defined in this instance")
|
|
181
|
+
|
|
182
|
+
if 'optimizer_state_dict' not in checkpoint:
|
|
183
|
+
raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
|
|
184
|
+
try:
|
|
185
|
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
186
|
+
except ValueError as e:
|
|
187
|
+
warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
|
|
188
|
+
|
|
189
|
+
epoch = checkpoint.get('epoch', -1)
|
|
190
|
+
loss = checkpoint.get('loss', float('inf'))
|
|
191
|
+
|
|
192
|
+
self.noise_predictor.to(self.device)
|
|
193
|
+
if self.conditional_model is not None:
|
|
194
|
+
self.conditional_model.to(self.device)
|
|
195
|
+
|
|
196
|
+
print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
|
|
197
|
+
return epoch, loss
|
|
198
|
+
|
|
199
|
+
@staticmethod
|
|
200
|
+
def warmup_scheduler(optimizer, warmup_epochs=10):
|
|
201
|
+
"""Creates a learning rate scheduler for warmup.
|
|
202
|
+
|
|
203
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
204
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
205
|
+
|
|
206
|
+
Parameters
|
|
207
|
+
----------
|
|
208
|
+
optimizer : torch.optim.Optimizer
|
|
209
|
+
Optimizer to apply the scheduler to.
|
|
210
|
+
warmup_epochs : int, optional
|
|
211
|
+
Number of epochs for the warmup phase (default: 10).
|
|
212
|
+
|
|
213
|
+
Returns
|
|
214
|
+
-------
|
|
215
|
+
torch.optim.lr_scheduler.LambdaLR
|
|
216
|
+
Learning rate scheduler for warmup.
|
|
217
|
+
"""
|
|
218
|
+
def lr_lambda(epoch):
|
|
219
|
+
if epoch < warmup_epochs:
|
|
220
|
+
return epoch / warmup_epochs
|
|
221
|
+
return 1.0
|
|
222
|
+
|
|
223
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
224
|
+
|
|
225
|
+
def forward(self):
|
|
226
|
+
"""Trains the LDM to predict noise added by the forward diffusion process in the latent space.
|
|
227
|
+
|
|
228
|
+
Executes the training loop, optimizing the noise predictor and conditional model
|
|
229
|
+
(if applicable) using mixed precision, gradient clipping, and learning rate
|
|
230
|
+
scheduling. Uses a pre-trained compressor model to encode images into the latent
|
|
231
|
+
space. Supports validation, early stopping, and checkpointing.
|
|
232
|
+
|
|
233
|
+
Returns
|
|
234
|
+
-------
|
|
235
|
+
tuple
|
|
236
|
+
A tuple containing:
|
|
237
|
+
- train_losses: List of mean training losses per epoch (list of float).
|
|
238
|
+
- best_val_loss: Best validation or training loss achieved (float).
|
|
239
|
+
|
|
240
|
+
Notes
|
|
241
|
+
-----
|
|
242
|
+
- Training uses mixed precision via `torch.cuda.amp` for efficiency.
|
|
243
|
+
- The compressor model is assumed pre-trained and set to evaluation mode.
|
|
244
|
+
- Checkpoints are saved when the validation (or training) loss improves, and on
|
|
245
|
+
early stopping.
|
|
246
|
+
- Early stopping is triggered if no improvement occurs for `patience` epochs.
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
self.noise_predictor.train()
|
|
250
|
+
self.noise_predictor.to(self.device)
|
|
251
|
+
if self.conditional_model is not None:
|
|
252
|
+
self.conditional_model.train()
|
|
253
|
+
self.conditional_model.to(self.device)
|
|
254
|
+
if self.compressor_model is not None:
|
|
255
|
+
self.compressor_model.eval() # the model is already trained
|
|
256
|
+
self.compressor_model.to(self.device)
|
|
257
|
+
|
|
258
|
+
scaler = GradScaler()
|
|
259
|
+
train_losses = []
|
|
260
|
+
best_val_loss = float("inf")
|
|
261
|
+
wait = 0
|
|
262
|
+
for epoch in range(self.max_epoch):
|
|
263
|
+
train_losses_ = []
|
|
264
|
+
for x, y in tqdm(self.data_loader):
|
|
265
|
+
x = x.to(self.device)
|
|
266
|
+
with torch.no_grad():
|
|
267
|
+
x, _ = self.compressor_model.encode(x) # using compressor model bring x to latent space
|
|
268
|
+
if self.conditional_model is not None:
|
|
269
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
270
|
+
y_list = [str(item) for item in y_list]
|
|
271
|
+
y_encoded = self.tokenizer(
|
|
272
|
+
y_list,
|
|
273
|
+
padding="max_length",
|
|
274
|
+
truncation=True,
|
|
275
|
+
max_length=self.max_length,
|
|
276
|
+
return_tensors="pt"
|
|
277
|
+
).to(self.device)
|
|
278
|
+
input_ids = y_encoded["input_ids"]
|
|
279
|
+
attention_mask = y_encoded["attention_mask"]
|
|
280
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
281
|
+
else:
|
|
282
|
+
y_encoded = None
|
|
283
|
+
|
|
284
|
+
self.optimizer.zero_grad()
|
|
285
|
+
with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
|
|
286
|
+
noise = torch.randn_like(x).to(self.device)
|
|
287
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
288
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
289
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
290
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
291
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
292
|
+
loss = self.objective(p_noise, noise)
|
|
293
|
+
scaler.scale(loss).backward()
|
|
294
|
+
|
|
295
|
+
nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
|
|
296
|
+
if self.conditional_model is not None:
|
|
297
|
+
nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
|
|
298
|
+
scaler.step(self.optimizer)
|
|
299
|
+
scaler.update()
|
|
300
|
+
self.warmup_lr_scheduler.step()
|
|
301
|
+
train_losses_.append(loss.item())
|
|
302
|
+
|
|
303
|
+
mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
|
|
304
|
+
train_losses.append(mean_train_loss)
|
|
305
|
+
print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
|
|
306
|
+
|
|
307
|
+
if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
|
|
308
|
+
val_loss = self.validate()
|
|
309
|
+
print(f" | Val Loss: {val_loss:.4f}")
|
|
310
|
+
current_best = val_loss
|
|
311
|
+
self.scheduler.step(val_loss)
|
|
312
|
+
else:
|
|
313
|
+
print()
|
|
314
|
+
current_best = mean_train_loss
|
|
315
|
+
self.scheduler.step(mean_train_loss)
|
|
316
|
+
|
|
317
|
+
if current_best < best_val_loss:
|
|
318
|
+
best_val_loss = current_best
|
|
319
|
+
wait = 0
|
|
320
|
+
try:
|
|
321
|
+
torch.save({
|
|
322
|
+
'epoch': epoch + 1,
|
|
323
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
324
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
325
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
326
|
+
'loss': best_val_loss,
|
|
327
|
+
'hyper_params_model': self.hyper_params_model,
|
|
328
|
+
'max_epoch': self.max_epoch,
|
|
329
|
+
}, self.store_path)
|
|
330
|
+
print(f"Model saved at epoch {epoch + 1}")
|
|
331
|
+
except Exception as e:
|
|
332
|
+
print(f"Failed to save model: {e}")
|
|
333
|
+
else:
|
|
334
|
+
wait += 1
|
|
335
|
+
if wait >= self.patience:
|
|
336
|
+
print("Early stopping triggered")
|
|
337
|
+
try:
|
|
338
|
+
torch.save({
|
|
339
|
+
'epoch': epoch + 1,
|
|
340
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
341
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
342
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
343
|
+
'loss': best_val_loss,
|
|
344
|
+
'hyper_params_model': self.hyper_params_model,
|
|
345
|
+
'max_epoch': self.max_epoch,
|
|
346
|
+
}, self.store_path + "_early_stop.pth")
|
|
347
|
+
print(f"Final model saved at {self.store_path}_early_stop.pth")
|
|
348
|
+
except Exception as e:
|
|
349
|
+
print(f"Failed to save final model: {e}")
|
|
350
|
+
break
|
|
351
|
+
|
|
352
|
+
return train_losses, best_val_loss
|
|
353
|
+
|
|
354
|
+
def validate(self):
|
|
355
|
+
"""Validates the LDM on the validation dataset.
|
|
356
|
+
|
|
357
|
+
Computes the validation loss using the noise predictor and forward diffusion
|
|
358
|
+
process in the latent space, with optional conditional inputs.
|
|
359
|
+
|
|
360
|
+
Returns
|
|
361
|
+
-------
|
|
362
|
+
float
|
|
363
|
+
Mean validation loss across the validation dataset.
|
|
364
|
+
|
|
365
|
+
Notes
|
|
366
|
+
-----
|
|
367
|
+
- Validation is performed with `torch.no_grad()` for efficiency.
|
|
368
|
+
- The compressor model is used to encode validation data into the latent space.
|
|
369
|
+
- The noise predictor and conditional model (if applicable) are set to evaluation
|
|
370
|
+
mode during validation and restored to training mode afterward.
|
|
371
|
+
"""
|
|
372
|
+
self.noise_predictor.eval()
|
|
373
|
+
if self.conditional_model is not None:
|
|
374
|
+
self.conditional_model.eval()
|
|
375
|
+
|
|
376
|
+
val_losses = []
|
|
377
|
+
with torch.no_grad():
|
|
378
|
+
for x, y in self.val_loader:
|
|
379
|
+
x = x.to(self.device)
|
|
380
|
+
with torch.no_grad():
|
|
381
|
+
x, _ = self.compressor_model.encode(x)
|
|
382
|
+
|
|
383
|
+
if self.conditional_model is not None:
|
|
384
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
385
|
+
y_list = [str(item) for item in y_list]
|
|
386
|
+
y_encoded = self.tokenizer(
|
|
387
|
+
y_list,
|
|
388
|
+
padding="max_length",
|
|
389
|
+
truncation=True,
|
|
390
|
+
max_length=self.max_length,
|
|
391
|
+
return_tensors="pt"
|
|
392
|
+
).to(self.device)
|
|
393
|
+
input_ids = y_encoded["input_ids"]
|
|
394
|
+
attention_mask = y_encoded["attention_mask"]
|
|
395
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
396
|
+
else:
|
|
397
|
+
y_encoded = None
|
|
398
|
+
|
|
399
|
+
noise = torch.randn_like(x).to(self.device)
|
|
400
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
401
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
402
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
403
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
404
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
405
|
+
loss = self.objective(p_noise, noise)
|
|
406
|
+
val_losses.append(loss.item())
|
|
407
|
+
|
|
408
|
+
mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
|
|
409
|
+
self.noise_predictor.train()
|
|
410
|
+
if self.conditional_model is not None:
|
|
411
|
+
self.conditional_model.train()
|
|
412
|
+
return mean_val_loss
|
sde/__init__.py
ADDED
|
File without changes
|
sde/forward_sde.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ForwardSDE(nn.Module):
    """Forward diffusion process for SDE-based generative models.

    Implements the forward diffusion process for score-based generative models using
    Stochastic Differential Equations (SDEs), supporting Variance Exploding (VE),
    Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE methods, as
    described in Song et al. (2021).

    Parameters
    ----------
    hyper_params : object
        Hyperparameter object containing SDE-specific parameters. Expected to have
        attributes:
        - `dt`: Time step size for SDE integration (float); read on every call.
        - `sigmas`: Sigma values for VE method (torch.Tensor, optional).
        - `betas`: Beta values for VP, sub-VP, or ODE methods (torch.Tensor).
        - `cum_betas`: Cumulative beta values for sub-VP method (torch.Tensor, optional).
    method : str
        SDE method to use. Supported methods: "ve", "vp", "sub-vp", "ode".
        The value is not validated at construction; an unsupported value raises
        `ValueError` when `forward` is called.

    Attributes
    ----------
    hyper_params : object
        Stores the provided hyperparameter object.
    method : str
        Selected SDE method.
    """

    def __init__(self, hyper_params, method):
        super().__init__()
        self.hyper_params = hyper_params
        self.method = method

    def forward(self, x0, noise, time_steps):
        """Applies the forward SDE diffusion process to the input data.

        Perturbs the input data `x0` by adding noise according to the specified SDE
        method at given time steps, incorporating drift and diffusion terms as applicable.

        Parameters
        ----------
        x0 : torch.Tensor
            Input data tensor with the batch on the first axis, e.g.
            (batch_size, channels, height, width).
        noise : torch.Tensor
            Gaussian noise tensor, same shape as `x0`. Ignored by the "ode" method.
        time_steps : torch.Tensor
            Tensor of time step indices (long), shape (batch_size,), used to index the
            schedule tensors on `hyper_params`.

        Returns
        -------
        torch.Tensor
            Noisy data tensor at the specified time steps, same shape as `x0`.

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods ("ve", "vp", "sub-vp", "ode").
        """
        dt = self.hyper_params.dt
        # Shape that broadcasts one coefficient per sample over the remaining axes.
        per_sample = (-1,) + (1,) * (x0.dim() - 1)

        if self.method == "ve":
            sigmas = self.hyper_params.sigmas
            sigma_t = sigmas[time_steps]
            # sigma_{t-1} is taken as 0 for samples at t == 0. This is decided per sample;
            # the index is clamped so t == 0 never wraps to the last schedule entry.
            prev_index = torch.clamp(time_steps - 1, min=0)
            sigma_t_prev = torch.where(time_steps > 0, sigmas[prev_index], torch.zeros_like(sigma_t))
            sigma_diff = torch.sqrt(torch.clamp(sigma_t ** 2 - sigma_t_prev ** 2, min=0))
            return x0 + noise * sigma_diff.view(per_sample)

        if self.method == "vp":
            betas = self.hyper_params.betas[time_steps].view(per_sample)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * dt) * noise
            return x0 + drift + diffusion

        if self.method == "sub-vp":
            betas = self.hyper_params.betas[time_steps].view(per_sample)
            cum_betas = self.hyper_params.cum_betas[time_steps].view(per_sample)
            drift = -0.5 * betas * x0 * dt
            diffusion = torch.sqrt(betas * (1 - torch.exp(-2 * cum_betas)) * dt) * noise
            return x0 + drift + diffusion

        if self.method == "ode":
            # Deterministic variant: drift term only, no noise added.
            betas = self.hyper_params.betas[time_steps].view(per_sample)
            drift = -0.5 * betas * x0 * dt
            return x0 + drift

        raise ValueError(f"Unknown method: {self.method}")
|