TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
sde/train_sde.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.cuda.amp import GradScaler, autocast
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
from torch.amp import GradScaler, autocast
|
|
6
|
+
from torch.optim.lr_scheduler import LambdaLR
|
|
7
|
+
from transformers import BertTokenizer
|
|
8
|
+
from forward_sde import ForwardSDE
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TrainSDE(nn.Module):
|
|
14
|
+
"""Trainer for score-based generative models using Stochastic Differential Equations.
|
|
15
|
+
|
|
16
|
+
Manages the training process for SDE-based generative models, optimizing a noise
|
|
17
|
+
predictor to learn the noise added by the forward SDE process, as described in Song
|
|
18
|
+
et al. (2021). Supports conditional training with text prompts, mixed precision,
|
|
19
|
+
learning rate scheduling, early stopping, and checkpointing.
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
----------
|
|
23
|
+
method : str
|
|
24
|
+
SDE method to use for forward diffusion. Supported methods: "ve", "vp", "sub-vp", "ode".
|
|
25
|
+
noise_predictor : nn.Module
|
|
26
|
+
Model to predict noise added during the forward SDE process.
|
|
27
|
+
hyper_params_model : nn.Module
|
|
28
|
+
Hyperparameter module (e.g., HyperParamsSDE) defining the noise schedule and SDE parameters.
|
|
29
|
+
data_loader : torch.utils.data.DataLoader
|
|
30
|
+
DataLoader for training data.
|
|
31
|
+
optimizer : torch.optim.Optimizer
|
|
32
|
+
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
33
|
+
objective : callable
|
|
34
|
+
Loss function to compute the difference between predicted and actual noise.
|
|
35
|
+
val_loader : torch.utils.data.DataLoader, optional
|
|
36
|
+
DataLoader for validation data, default None.
|
|
37
|
+
max_epoch : int, optional
|
|
38
|
+
Maximum number of training epochs (default: 1000).
|
|
39
|
+
device : torch.device, optional
|
|
40
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
41
|
+
conditional_model : nn.Module, optional
|
|
42
|
+
Model for conditional generation (e.g., text embeddings), default None.
|
|
43
|
+
tokenizer : BertTokenizer, optional
|
|
44
|
+
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
45
|
+
max_length : int, optional
|
|
46
|
+
Maximum length for tokenized prompts (default: 77).
|
|
47
|
+
store_path : str, optional
|
|
48
|
+
Path to save model checkpoints (default: "sde_model.pth").
|
|
49
|
+
patience : int, optional
|
|
50
|
+
Number of epochs to wait for improvement before early stopping (default: 10).
|
|
51
|
+
warmup_epochs : int, optional
|
|
52
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
53
|
+
val_frequency : int, optional
|
|
54
|
+
Frequency (in epochs) for validation (default: 10).
|
|
55
|
+
|
|
56
|
+
Attributes
|
|
57
|
+
----------
|
|
58
|
+
device : torch.device
|
|
59
|
+
Device used for computation.
|
|
60
|
+
method : str
|
|
61
|
+
Selected SDE method.
|
|
62
|
+
noise_predictor : nn.Module
|
|
63
|
+
Noise prediction model.
|
|
64
|
+
hyper_params_model : nn.Module
|
|
65
|
+
Hyperparameter module for the noise schedule and SDE parameters.
|
|
66
|
+
conditional_model : nn.Module or None
|
|
67
|
+
Conditional model for text-based training, if provided.
|
|
68
|
+
optimizer : torch.optim.Optimizer
|
|
69
|
+
Optimizer for training.
|
|
70
|
+
objective : callable
|
|
71
|
+
Loss function for training.
|
|
72
|
+
store_path : str
|
|
73
|
+
Path for saving checkpoints.
|
|
74
|
+
data_loader : torch.utils.data.DataLoader
|
|
75
|
+
Training data loader.
|
|
76
|
+
val_loader : torch.utils.data.DataLoader or None
|
|
77
|
+
Validation data loader, if provided.
|
|
78
|
+
max_epoch : int
|
|
79
|
+
Maximum training epochs.
|
|
80
|
+
max_length : int
|
|
81
|
+
Maximum length for tokenized prompts.
|
|
82
|
+
patience : int
|
|
83
|
+
Patience for early stopping.
|
|
84
|
+
scheduler : torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
85
|
+
Learning rate scheduler based on validation or training loss.
|
|
86
|
+
forward_diffusion : ForwardSDE
|
|
87
|
+
Forward SDE diffusion module.
|
|
88
|
+
warmup_lr_scheduler : torch.optim.lr_scheduler.LambdaLR
|
|
89
|
+
Learning rate scheduler for warmup.
|
|
90
|
+
val_frequency : int
|
|
91
|
+
Frequency for validation.
|
|
92
|
+
tokenizer : BertTokenizer
|
|
93
|
+
Tokenizer for text prompts.
|
|
94
|
+
|
|
95
|
+
Raises
|
|
96
|
+
------
|
|
97
|
+
ValueError
|
|
98
|
+
If the default tokenizer ("bert-base-uncased") fails to load and no tokenizer is provided.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
def __init__(self, method, noise_predictor, hyper_params_model, data_loader, optimizer, objective, val_loader=None,
|
|
102
|
+
max_epoch=1000, device=None, conditional_model=None, tokenizer=None, max_length=77,
|
|
103
|
+
store_path=None, patience=10, warmup_epochs=100, val_frequency=10):
|
|
104
|
+
super().__init__()
|
|
105
|
+
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
106
|
+
self.method = method
|
|
107
|
+
self.noise_predictor = noise_predictor
|
|
108
|
+
self.hyper_params_model = hyper_params_model.to(self.device)
|
|
109
|
+
self.conditional_model = conditional_model
|
|
110
|
+
self.optimizer = optimizer
|
|
111
|
+
self.objective = objective
|
|
112
|
+
self.store_path = store_path or "sde_model.pth"
|
|
113
|
+
self.data_loader = data_loader
|
|
114
|
+
self.val_loader = val_loader
|
|
115
|
+
self.max_epoch = max_epoch
|
|
116
|
+
self.max_length = max_length
|
|
117
|
+
self.patience = patience
|
|
118
|
+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=self.patience, factor=0.5)
|
|
119
|
+
self.forward_diffusion = ForwardSDE(hyper_params=self.hyper_params_model, method=self.method).to(self.device)
|
|
120
|
+
self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
121
|
+
self.val_frequency = val_frequency
|
|
122
|
+
if tokenizer is None:
|
|
123
|
+
try:
|
|
124
|
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
125
|
+
except Exception as e:
|
|
126
|
+
raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def load_checkpoint(self, checkpoint_path):
|
|
130
|
+
"""Loads a training checkpoint to resume training.
|
|
131
|
+
|
|
132
|
+
Restores the state of the noise predictor, conditional model (if applicable),
|
|
133
|
+
and optimizer from a saved checkpoint.
|
|
134
|
+
|
|
135
|
+
Parameters
|
|
136
|
+
----------
|
|
137
|
+
checkpoint_path : str
|
|
138
|
+
Path to the checkpoint file.
|
|
139
|
+
|
|
140
|
+
Returns
|
|
141
|
+
-------
|
|
142
|
+
tuple
|
|
143
|
+
A tuple containing:
|
|
144
|
+
- epoch: The epoch at which the checkpoint was saved (int).
|
|
145
|
+
- loss: The loss at the checkpoint (float).
|
|
146
|
+
|
|
147
|
+
Raises
|
|
148
|
+
------
|
|
149
|
+
FileNotFoundError
|
|
150
|
+
If the checkpoint file is not found.
|
|
151
|
+
KeyError
|
|
152
|
+
If the checkpoint is missing required keys ('model_state_dict_noise_predictor'
|
|
153
|
+
or 'optimizer_state_dict').
|
|
154
|
+
|
|
155
|
+
Warns
|
|
156
|
+
-----
|
|
157
|
+
warnings.warn
|
|
158
|
+
If the optimizer state cannot be loaded, if the checkpoint contains a
|
|
159
|
+
conditional model state but none is defined, or if no conditional model
|
|
160
|
+
state is provided when expected.
|
|
161
|
+
"""
|
|
162
|
+
try:
|
|
163
|
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
|
164
|
+
except FileNotFoundError:
|
|
165
|
+
raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
|
|
166
|
+
|
|
167
|
+
if 'model_state_dict_noise_predictor' not in checkpoint:
|
|
168
|
+
raise KeyError("Checkpoint missing 'model_state_dict_noise_predictor' key")
|
|
169
|
+
self.noise_predictor.load_state_dict(checkpoint['model_state_dict_noise_predictor'])
|
|
170
|
+
|
|
171
|
+
if self.conditional_model is not None:
|
|
172
|
+
if 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
173
|
+
self.conditional_model.load_state_dict(checkpoint['model_state_dict_conditional'])
|
|
174
|
+
else:
|
|
175
|
+
warnings.warn(
|
|
176
|
+
"Checkpoint contains no 'model_state_dict_conditional' or it is None, skipping conditional model loading")
|
|
177
|
+
elif 'model_state_dict_conditional' in checkpoint and checkpoint['model_state_dict_conditional'] is not None:
|
|
178
|
+
warnings.warn(
|
|
179
|
+
"Checkpoint contains conditional model state, but no conditional model is defined in this instance")
|
|
180
|
+
|
|
181
|
+
if 'optimizer_state_dict' not in checkpoint:
|
|
182
|
+
raise KeyError("Checkpoint missing 'optimizer_state_dict' key")
|
|
183
|
+
try:
|
|
184
|
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
185
|
+
except ValueError as e:
|
|
186
|
+
warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")
|
|
187
|
+
|
|
188
|
+
epoch = checkpoint.get('epoch', -1)
|
|
189
|
+
loss = checkpoint.get('loss', float('inf'))
|
|
190
|
+
|
|
191
|
+
self.noise_predictor.to(self.device)
|
|
192
|
+
if self.conditional_model is not None:
|
|
193
|
+
self.conditional_model.to(self.device)
|
|
194
|
+
|
|
195
|
+
print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
|
|
196
|
+
return epoch, loss
|
|
197
|
+
|
|
198
|
+
@staticmethod
|
|
199
|
+
def warmup_scheduler(optimizer, warmup_epochs=10):
|
|
200
|
+
"""Creates a learning rate scheduler for warmup.
|
|
201
|
+
|
|
202
|
+
Generates a scheduler that linearly increases the learning rate from 0 to the
|
|
203
|
+
optimizer's initial value over the specified warmup epochs, then maintains it.
|
|
204
|
+
|
|
205
|
+
Parameters
|
|
206
|
+
----------
|
|
207
|
+
optimizer : torch.optim.Optimizer
|
|
208
|
+
Optimizer to apply the scheduler to.
|
|
209
|
+
warmup_epochs : int, optional
|
|
210
|
+
Number of epochs for the warmup phase (default: 10).
|
|
211
|
+
|
|
212
|
+
Returns
|
|
213
|
+
-------
|
|
214
|
+
torch.optim.lr_scheduler.LambdaLR
|
|
215
|
+
Learning rate scheduler for warmup.
|
|
216
|
+
"""
|
|
217
|
+
def lr_lambda(epoch):
|
|
218
|
+
if epoch < warmup_epochs:
|
|
219
|
+
return epoch / warmup_epochs
|
|
220
|
+
return 1.0
|
|
221
|
+
|
|
222
|
+
return LambdaLR(optimizer, lr_lambda)
|
|
223
|
+
|
|
224
|
+
def forward(self):
|
|
225
|
+
"""Trains the SDE model to predict noise added by the forward diffusion process.
|
|
226
|
+
|
|
227
|
+
Executes the training loop, optimizing the noise predictor and conditional model
|
|
228
|
+
(if applicable) using mixed precision, gradient clipping, and learning rate
|
|
229
|
+
scheduling. Supports validation, early stopping, and checkpointing.
|
|
230
|
+
|
|
231
|
+
Returns
|
|
232
|
+
-------
|
|
233
|
+
tuple
|
|
234
|
+
A tuple containing:
|
|
235
|
+
- train_losses: List of mean training losses per epoch (list of float).
|
|
236
|
+
- best_val_loss: Best validation or training loss achieved (float).
|
|
237
|
+
|
|
238
|
+
Notes
|
|
239
|
+
-----
|
|
240
|
+
- Training uses mixed precision via `torch.cuda.amp` or `torch.amp` for efficiency.
|
|
241
|
+
- Checkpoints are saved when the validation (or training) loss improves, and on
|
|
242
|
+
early stopping.
|
|
243
|
+
- Early stopping is triggered if no improvement occurs for `patience` epochs.
|
|
244
|
+
"""
|
|
245
|
+
self.noise_predictor.train()
|
|
246
|
+
self.noise_predictor.to(self.device)
|
|
247
|
+
if self.conditional_model is not None:
|
|
248
|
+
self.conditional_model.train()
|
|
249
|
+
self.conditional_model.to(self.device)
|
|
250
|
+
|
|
251
|
+
scaler = GradScaler()
|
|
252
|
+
train_losses = []
|
|
253
|
+
best_val_loss = float("inf")
|
|
254
|
+
wait = 0
|
|
255
|
+
for epoch in range(self.max_epoch):
|
|
256
|
+
train_losses_ = []
|
|
257
|
+
for x, y in tqdm(self.data_loader):
|
|
258
|
+
x = x.to(self.device)
|
|
259
|
+
|
|
260
|
+
if self.conditional_model is not None:
|
|
261
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
262
|
+
y_list = [str(item) for item in y_list]
|
|
263
|
+
y_encoded = self.tokenizer(
|
|
264
|
+
y_list,
|
|
265
|
+
padding="max_length",
|
|
266
|
+
truncation=True,
|
|
267
|
+
max_length=self.max_length,
|
|
268
|
+
return_tensors="pt"
|
|
269
|
+
).to(self.device)
|
|
270
|
+
input_ids = y_encoded["input_ids"]
|
|
271
|
+
attention_mask = y_encoded["attention_mask"]
|
|
272
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
273
|
+
else:
|
|
274
|
+
y_encoded = None
|
|
275
|
+
|
|
276
|
+
self.optimizer.zero_grad()
|
|
277
|
+
with autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
|
|
278
|
+
noise = torch.randn_like(x).to(self.device)
|
|
279
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
280
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
281
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
282
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
283
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
284
|
+
loss = self.objective(p_noise, noise)
|
|
285
|
+
scaler.scale(loss).backward()
|
|
286
|
+
nn.utils.clip_grad_norm_(self.noise_predictor.parameters(), max_norm=1.0)
|
|
287
|
+
if self.conditional_model is not None:
|
|
288
|
+
nn.utils.clip_grad_norm_(self.conditional_model.parameters(), max_norm=1.0)
|
|
289
|
+
scaler.step(self.optimizer)
|
|
290
|
+
scaler.update()
|
|
291
|
+
self.warmup_lr_scheduler.step()
|
|
292
|
+
train_losses_.append(loss.item())
|
|
293
|
+
|
|
294
|
+
mean_train_loss = torch.mean(torch.tensor(train_losses_)).item()
|
|
295
|
+
train_losses.append(mean_train_loss)
|
|
296
|
+
print(f"\nEpoch: {epoch + 1} | Train Loss: {mean_train_loss:.4f}", end="")
|
|
297
|
+
|
|
298
|
+
if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
|
|
299
|
+
val_loss = self.validate()
|
|
300
|
+
print(f" | Val Loss: {val_loss:.4f}")
|
|
301
|
+
current_best = val_loss
|
|
302
|
+
self.scheduler.step(val_loss)
|
|
303
|
+
else:
|
|
304
|
+
print()
|
|
305
|
+
current_best = mean_train_loss
|
|
306
|
+
self.scheduler.step(mean_train_loss)
|
|
307
|
+
|
|
308
|
+
if current_best < best_val_loss:
|
|
309
|
+
best_val_loss = current_best
|
|
310
|
+
wait = 0
|
|
311
|
+
try:
|
|
312
|
+
torch.save({
|
|
313
|
+
'epoch': epoch+1,
|
|
314
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
315
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
316
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
317
|
+
'loss': best_val_loss,
|
|
318
|
+
'hyper_params_model': self.hyper_params_model,
|
|
319
|
+
'max_epoch': self.max_epoch,
|
|
320
|
+
}, self.store_path)
|
|
321
|
+
print(f"Model saved at epoch {epoch + 1}")
|
|
322
|
+
except Exception as e:
|
|
323
|
+
print(f"Failed to save model: {e}")
|
|
324
|
+
else:
|
|
325
|
+
wait += 1
|
|
326
|
+
if wait >= self.patience:
|
|
327
|
+
print("Early stopping triggered")
|
|
328
|
+
try:
|
|
329
|
+
torch.save({
|
|
330
|
+
'epoch': epoch+1,
|
|
331
|
+
'model_state_dict_noise_predictor': self.noise_predictor.state_dict(),
|
|
332
|
+
'model_state_dict_conditional': self.conditional_model.state_dict() if self.conditional_model is not None else None,
|
|
333
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
334
|
+
'loss': best_val_loss,
|
|
335
|
+
'hyper_params_model': self.hyper_params_model,
|
|
336
|
+
'max_epoch': self.max_epoch,
|
|
337
|
+
}, self.store_path + "_early_stop.pth")
|
|
338
|
+
print(f"Final model saved at {self.store_path}_early_stop.pth")
|
|
339
|
+
except Exception as e:
|
|
340
|
+
print(f"Failed to save final model: {e}")
|
|
341
|
+
break
|
|
342
|
+
|
|
343
|
+
return train_losses, best_val_loss
|
|
344
|
+
|
|
345
|
+
def validate(self):
|
|
346
|
+
"""Validates the SDE model on the validation dataset.
|
|
347
|
+
|
|
348
|
+
Computes the validation loss using the noise predictor and forward SDE diffusion
|
|
349
|
+
process, with optional conditional inputs.
|
|
350
|
+
|
|
351
|
+
Returns
|
|
352
|
+
-------
|
|
353
|
+
float
|
|
354
|
+
Mean validation loss across the validation dataset.
|
|
355
|
+
|
|
356
|
+
Notes
|
|
357
|
+
-----
|
|
358
|
+
- Validation is performed with `torch.no_grad()` for efficiency.
|
|
359
|
+
- The noise predictor and conditional model (if applicable) are set to evaluation
|
|
360
|
+
mode during validation and restored to training mode afterward.
|
|
361
|
+
"""
|
|
362
|
+
self.noise_predictor.eval()
|
|
363
|
+
if self.conditional_model is not None:
|
|
364
|
+
self.conditional_model.eval()
|
|
365
|
+
|
|
366
|
+
val_losses = []
|
|
367
|
+
with torch.no_grad():
|
|
368
|
+
for x, y in self.val_loader:
|
|
369
|
+
x = x.to(self.device)
|
|
370
|
+
|
|
371
|
+
if self.conditional_model is not None:
|
|
372
|
+
y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y
|
|
373
|
+
y_list = [str(item) for item in y_list]
|
|
374
|
+
y_encoded = self.tokenizer(
|
|
375
|
+
y_list,
|
|
376
|
+
padding="max_length",
|
|
377
|
+
truncation=True,
|
|
378
|
+
max_length=self.max_length,
|
|
379
|
+
return_tensors="pt"
|
|
380
|
+
).to(self.device)
|
|
381
|
+
input_ids = y_encoded["input_ids"]
|
|
382
|
+
attention_mask = y_encoded["attention_mask"]
|
|
383
|
+
y_encoded = self.conditional_model(input_ids, attention_mask)
|
|
384
|
+
else:
|
|
385
|
+
y_encoded = None
|
|
386
|
+
|
|
387
|
+
noise = torch.randn_like(x).to(self.device)
|
|
388
|
+
t = torch.randint(0, self.hyper_params_model.num_steps, (x.shape[0],)).to(self.device)
|
|
389
|
+
assert x.device == noise.device == t.device, "Device mismatch detected"
|
|
390
|
+
assert t.shape[0] == x.shape[0], "Timestep batch size mismatch"
|
|
391
|
+
noisy_x = self.forward_diffusion(x, noise, t)
|
|
392
|
+
p_noise = self.noise_predictor(noisy_x, t, y_encoded)
|
|
393
|
+
loss = self.objective(p_noise, noise)
|
|
394
|
+
val_losses.append(loss.item())
|
|
395
|
+
|
|
396
|
+
mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
|
|
397
|
+
self.noise_predictor.train()
|
|
398
|
+
if self.conditional_model is not None:
|
|
399
|
+
self.conditional_model.train()
|
|
400
|
+
return mean_val_loss
|
torchdiff/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
# Version string of the torchdiff package.
__version__ = "2.0.0"
|
|
2
|
+
|
|
3
|
+
from .ddim import ForwardDDIM, ReverseDDIM, VarianceSchedulerDDIM, TrainDDIM, SampleDDIM
|
|
4
|
+
from .ddpm import ForwardDDPM, ReverseDDPM, VarianceSchedulerDDPM, TrainDDPM, SampleDDPM
|
|
5
|
+
from .ldm import TrainLDM, TrainAE, AutoencoderLDM, SampleLDM
|
|
6
|
+
from .sde import ForwardSDE, ReverseSDE, VarianceSchedulerSDE, TrainSDE, SampleSDE
|
|
7
|
+
from .unclip import ForwardUnCLIP, ReverseUnCLIP, VarianceSchedulerUnCLIP, CLIPEncoder, SampleUnCLIP, UnClipDecoder, UnCLIPTransformerPrior, CLIPContextProjection, CLIPEmbeddingProjection, TrainUnClipDecoder, SampleUnCLIP, UpsamplerUnCLIP, TrainUpsamplerUnCLIP
|
|
8
|
+
from .utils import NoisePredictor, TextEncoder, Metrics
|