TorchDiff 2.0.0__py3-none-any.whl
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
unclip/train_prior.py
ADDED
@@ -0,0 +1,757 @@
import torch
import torch.nn as nn
from typing import Optional, List, Tuple, Union, Callable
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm
import warnings
import os


class TrainUnCLIPPrior(nn.Module):
    """Trainer for the UnCLIPTransformerPrior model.

    Handles the training of the UnCLIP prior model to predict clean image embeddings from
    noisy image embeddings and text embeddings, with support for dimension reduction,
    mixed precision training, and distributed training.

    Parameters
    ----------
    `prior_model` : nn.Module
        The UnCLIP prior model to be trained (e.g., UnCLIPTransformerPrior).
    `clip_model` : nn.Module
        CLIP model for encoding text and images.
    `train_loader` : torch.utils.data.DataLoader
        DataLoader for training data.
    `optimizer` : torch.optim.Optimizer
        Optimizer for training the prior model.
    `objective` : Callable
        Loss function to compute the difference between predicted and target embeddings.
    `val_loader` : torch.utils.data.DataLoader, optional
        DataLoader for validation data, default None.
    `max_epochs` : int, optional
        Maximum number of training epochs (default: 1000).
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `store_path` : str, optional
        Directory path to save model checkpoints, default None.
    `patience` : int, optional
        Number of epochs to wait for improvement before early stopping (default: 100).
    `warmup_epochs` : int, optional
        Number of epochs for learning rate warmup (default: 100).
    `val_frequency` : int, optional
        Frequency (in epochs) for validation (default: 10).
    `use_ddp` : bool, optional
        Whether to use Distributed Data Parallel training (default: False).
    `grad_accumulation_steps` : int, optional
        Number of gradient accumulation steps before optimizer update (default: 1).
    `log_frequency` : int, optional
        Frequency (in epochs) for printing training progress (default: 1).
    `use_compilation` : bool, optional
        Whether to compile models for optimization (default: False).
    `embedding_output_range` : Tuple[float, float], optional
        Range for clamping output embeddings (default: (-1.0, 1.0)).
    `reduce_clip_embedding_dim` : bool, optional
        Whether to apply dimension reduction to embeddings (default: True).
    `transformer_embedding_dim` : int, optional
        Target dimensionality for reduced embeddings (default: 319).
    `normalize_clip_embeddings` : bool, optional
        Whether to normalize CLIP embeddings (default: True).
    """

    def __init__(
        self,
        prior_model: nn.Module,
        clip_model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        store_path: Optional[str] = None,
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False,
        embedding_output_range: Tuple[float, float] = (-1.0, 1.0),
        reduce_clip_embedding_dim: bool = True,
        transformer_embedding_dim: int = 319,
        normalize_clip_embeddings: bool = True
    ) -> None:
        super().__init__()

        # Training configuration
        self.use_ddp = use_ddp
        self.grad_accumulation_steps = grad_accumulation_steps
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # Setup distributed training
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        # Core models
        self.prior_model = prior_model.to(self.device)
        self.clip_model = clip_model.to(self.device)

        # Training components
        self.optimizer = optimizer
        self.objective = objective
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Training parameters
        self.max_epochs = max_epochs
        self.patience = patience
        self.val_frequency = val_frequency
        self.log_frequency = log_frequency
        self.use_compilation = use_compilation
        self.embedding_output_range = embedding_output_range
        self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.transformer_embedding_dim = transformer_embedding_dim

        # Checkpoint management (directory is created lazily in _save_checkpoint)
        self.store_path = store_path

        # Learning rate scheduling
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    def _setup_ddp(self) -> None:
        """Sets up Distributed Data Parallel training configuration.

        Initializes the process group, sets up rank information, and configures the CUDA
        device for the current process.

        Raises
        ------
        ValueError
            If required DDP environment variables (RANK, LOCAL_RANK, WORLD_SIZE) are not set.
        RuntimeError
            If CUDA is not available when DDP is enabled.
        """
        required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
        for var in required_env_vars:
            if var not in os.environ:
                raise ValueError(f"DDP enabled but {var} environment variable not set")

        # Ensure CUDA is available for DDP
        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA but CUDA is not available")

        # Initialize process group only if not already initialized
        if not torch.distributed.is_initialized():
            init_process_group(backend="nccl")

        # Get rank information
        self.ddp_rank = int(os.environ["RANK"])              # Global rank across all nodes
        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank on current node
        self.ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes

        # Set device and make it current
        self.device = torch.device(f"cuda:{self.ddp_local_rank}")
        torch.cuda.set_device(self.device)

        # Master process handles logging, checkpointing, etc.
        self.master_process = self.ddp_rank == 0

        if self.master_process:
            print(f"DDP initialized with world_size={self.ddp_world_size}")

    def _setup_single_gpu(self) -> None:
        """Sets up single GPU or CPU training configuration.

        Configures the training setup for single-device operation, setting rank and process
        information for non-DDP training.
        """
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True

    @staticmethod
    def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
        """Creates a learning rate scheduler for warmup.

        Generates a scheduler that linearly increases the learning rate from 0 to the
        optimizer's initial value over the specified warmup epochs.

        Parameters
        ----------
        `optimizer` : torch.optim.Optimizer
            Optimizer to apply the scheduler to.
        `warmup_epochs` : int
            Number of epochs for the warmup phase.

        Returns
        -------
        lr_scheduler : torch.optim.lr_scheduler.LambdaLR
            Learning rate scheduler for warmup.
        """
        def lr_lambda(epoch):
            return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
        return LambdaLR(optimizer, lr_lambda)
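
    # For example (illustrative): forward() steps this scheduler once per optimizer
    # update rather than once per epoch, so `warmup_epochs` effectively counts
    # optimizer steps here. With warmup_epochs=100 the LR factor ramps
    # 0.00 -> 0.50 -> 1.00 at steps 0, 50, and 100, then holds at 1.00.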

    def _wrap_models_for_ddp(self) -> None:
        """Wraps the prior model with DistributedDataParallel for multi-GPU training.

        Configures the prior model for DDP, setting device IDs and handling unused parameters.
        """
        if self.use_ddp:
            # Wrap prior with DDP
            self.prior_model = DDP(
                self.prior_model,
                device_ids=[self.ddp_local_rank],
                find_unused_parameters=True
            )

    def _compile_models(self) -> None:
        """Compiles models for optimization if supported.

        Attempts to compile the prior model using torch.compile for performance optimization,
        with fallback to uncompiled models if compilation fails.
        """
        if self.use_compilation:
            try:
                self.prior_model = torch.compile(self.prior_model)

                if self.master_process:
                    print("Models compiled successfully")
            except Exception as e:
                if self.master_process:
                    print(f"Model compilation failed: {e}. Continuing without compilation.")

    def forward(self) -> Tuple[List[float], float]:
        """Trains the UnCLIP prior model.

        Executes the training loop, optimizing the prior model to predict clean image embeddings
        from noisy embeddings and text conditions, with support for validation, early stopping,
        and checkpointing.

        Returns
        -------
        train_losses : List[float]
            List of mean training losses per epoch.
        best_val_loss : float
            Best validation or training loss achieved.
        """
        # Set models to training mode
        self.prior_model.train()

        # Compile and wrap models
        self._compile_models()
        self._wrap_models_for_ddp()

        # Initialize training components
        scaler = torch.GradScaler()
        train_losses = []
        best_val_loss = float("inf")
        wait = 0

        # Main training loop
        for epoch in range(self.max_epochs):
            # Set epoch for distributed sampler if using DDP
            if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
                self.train_loader.sampler.set_epoch(epoch)

            train_losses_epoch = []

            # Training step loop with gradient accumulation
            for step, (x, y) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
                x = x.to(self.device, non_blocking=True)

                # Forward pass with mixed precision (self.device is a torch.device,
                # so its .type attribute is compared, not the device object itself)
                with torch.autocast(device_type=self.device.type):
                    loss = self._compute_training_loss(x, y)
                    loss = loss / self.grad_accumulation_steps

                # Backward pass (once per micro-batch; gradients accumulate)
                scaler.scale(loss).backward()

                # Optimizer step with gradient accumulation
                if (step + 1) % self.grad_accumulation_steps == 0:
                    self._optimizer_step(scaler)
                    # Update learning rate (warmup scheduler)
                    self.warmup_lr_scheduler.step()

                # Record loss (unscaled)
                train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

            # Compute and sync training loss
            mean_train_loss = self._compute_mean_loss(train_losses_epoch)
            train_losses.append(mean_train_loss)

            # Print training progress (only master process)
            if self.master_process and (epoch + 1) % self.log_frequency == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}", end="")

            # Validation and checkpointing
            current_loss = mean_train_loss
            if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
                val_loss = self.validate()
                current_loss = val_loss

                if self.master_process:
                    print(f" | Val Loss: {val_loss:.4f}")
            elif self.master_process:
                print()

            # Learning rate scheduling
            self.scheduler.step(current_loss)

            # Save checkpoint and early stopping
            if self.master_process:
                if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                    best_val_loss = current_loss
                    wait = 0
                    self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
                else:
                    wait += 1
                    if wait >= self.patience:
                        print("Early stopping triggered")
                        self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                        break

        # Cleanup
        if self.use_ddp:
            destroy_process_group()

        return train_losses, best_val_loss
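
    # Worked example (illustrative numbers): with grad_accumulation_steps=k,
    # backward() runs on every micro-batch but the optimizer steps once per k
    # micro-batches, so the effective batch size is k * loader_batch_size per
    # process (times world_size under DDP). E.g. k=2, batch_size=32, 4 GPUs
    # -> effective batch of 2 * 32 * 4 = 256.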

    def _compute_training_loss(self, images: torch.Tensor, texts: List[str]) -> torch.Tensor:
        """Computes the training loss for the UnCLIP prior model.

        Calculates the loss by encoding images and text with CLIP, applying forward diffusion,
        predicting clean embeddings, and comparing with target embeddings.

        Parameters
        ----------
        `images` : torch.Tensor
            Input images, shape (batch_size, channels, height, width).
        `texts` : List[str]
            List of text prompts for conditioning.

        Returns
        -------
        loss : torch.Tensor
            Loss value computed between predicted and target embeddings.
        """
        with torch.no_grad():
            # Encode text and image with CLIP
            text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
            image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)

        # Reduce dimensionality (optional); uses the same projection modules as validate()
        if self.reduce_clip_embedding_dim:
            text_embeddings = self.prior_model.text_projection(text_embeddings)
            image_embeddings = self.prior_model.image_projection(image_embeddings)

        # Sample timestep t ~ Uniform({0, ..., num_steps - 1})
        batch_size = image_embeddings.shape[0]
        timesteps = torch.randint(0, self.prior_model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)

        # Sample noise ε ~ N(0, I)
        noise = torch.randn_like(image_embeddings)

        # Compute noised embedding z_{i,t}
        noisy_image_embeddings = self.prior_model.forward_diffusion(image_embeddings, noise, timesteps)

        # Predict unnoised embedding ẑ_i
        predicted_image_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)

        # Transform back to original space if using dimension reduction
        if self.reduce_clip_embedding_dim:
            predicted_image_embeddings = self.prior_model.image_projection.inverse_transform(predicted_image_embeddings)
            target_embeddings = self.prior_model.image_projection.inverse_transform(image_embeddings)
        else:
            target_embeddings = image_embeddings

        # Compute loss L = ||ẑ_i - z_i||²
        loss = self.objective(predicted_image_embeddings, target_embeddings)
        return loss
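
    # Shape sketch (illustrative, assuming the 512-dim CLIP ViT-B/32 encoder and
    # 320-dim projections from the example at the bottom of this file):
    #   text_embeddings:        (B, 512) --text_projection-->  (B, 320)
    #   image_embeddings:       (B, 512) --image_projection--> (B, 320)
    #   timesteps:              (B,) drawn from {0, ..., num_steps - 1}
    #   noisy_image_embeddings: (B, 320); predictions are mapped back to (B, 512)
    #   via inverse_transform before the loss is taken in the original CLIP space.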

    def _optimizer_step(self, scaler: torch.GradScaler) -> None:
        """Performs an optimizer step with gradient clipping.

        Applies gradient clipping, updates the optimizer with scaled gradients, and resets
        gradients for the next iteration.

        Parameters
        ----------
        `scaler` : torch.GradScaler
            Gradient scaler for mixed precision training.
        """
        scaler.unscale_(self.optimizer)

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.prior_model.parameters(), max_norm=1.0)

        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad()

    def _compute_mean_loss(self, losses: List[float]) -> float:
        """Computes the mean loss and synchronizes across processes if using DDP.

        Calculates the mean of the provided loss values and performs an all-reduce operation
        in DDP mode to synchronize the loss across processes.

        Parameters
        ----------
        `losses` : List[float]
            List of loss values from a training or validation epoch.

        Returns
        -------
        mean_loss : float
            Mean loss value, synchronized across processes if DDP is enabled.
        """
        mean_loss = torch.tensor(losses).mean().item()

        if self.use_ddp:
            loss_tensor = torch.tensor(mean_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_loss = loss_tensor.item()

        return mean_loss
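
    # For example (illustrative): each rank first averages its own losses, then
    # ReduceOp.AVG averages across ranks, so rank means of 0.8 and 1.2 on two
    # ranks are reported as 1.0 on every process.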

    def validate(self) -> float:
        """Validates the UnCLIP prior model.

        Computes the validation loss by encoding images and text, applying forward diffusion,
        predicting clean embeddings, and comparing with target embeddings.

        Returns
        -------
        val_loss : float
            Mean validation loss, synchronized across processes if DDP is enabled.
        """
        self.prior_model.eval()

        val_losses = []

        with torch.no_grad():
            for images, texts in self.val_loader:
                images = images.to(self.device, non_blocking=True)

                # Get embeddings
                text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
                image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
                original_image_embeddings = image_embeddings.clone()

                if self.reduce_clip_embedding_dim:
                    text_embeddings = self.prior_model.text_projection(text_embeddings)
                    image_embeddings = self.prior_model.image_projection(image_embeddings)

                # Forward diffusion
                batch_size = image_embeddings.shape[0]
                timesteps = torch.randint(0, self.prior_model.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
                noise = torch.randn_like(image_embeddings)
                noisy_image_embeddings = self.prior_model.forward_diffusion(image_embeddings, noise, timesteps)

                # Predict
                predicted_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)

                if self.reduce_clip_embedding_dim:
                    predicted_embeddings = self.prior_model.image_projection.inverse_transform(predicted_embeddings)

                # Compute loss
                loss = self.objective(predicted_embeddings, original_image_embeddings)
                val_losses.append(loss.item())

        # Compute averages
        val_loss = self._compute_mean_loss(val_losses)

        # Return to training mode
        self.prior_model.train()

        return val_loss

    def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "", is_best: bool = False) -> None:
        """Saves a model checkpoint.

        Saves the state of the prior model and optimizer to a checkpoint file, with options
        for best model or early stopping checkpoints.

        Parameters
        ----------
        `epoch` : int
            Current epoch number.
        `loss` : float
            Current loss value.
        `suffix` : str, optional
            Suffix to append to the checkpoint filename, default "".
        `is_best` : bool, optional
            Whether to save the checkpoint as the best model, default False.
        """
        # store_path is optional, so skip saving rather than failing on a None path
        if self.store_path is None:
            warnings.warn("store_path is None; skipping checkpoint save")
            return

        try:
            # Get state dicts
            prior_state = (
                self.prior_model.module.state_dict() if self.use_ddp
                else self.prior_model.state_dict()
            )

            checkpoint = {
                'epoch': epoch,
                'prior_model_state_dict': prior_state,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': loss,
                'max_epochs': self.max_epochs,
            }

            # Create the directory if it doesn't exist
            os.makedirs(self.store_path, exist_ok=True)

            # Define the checkpoint filename
            if is_best:
                filename = "best_model.pth"
            else:
                filename = f"checkpoint_epoch_{epoch}{suffix}.pth"

            # Construct the full save path
            save_path = os.path.join(self.store_path, filename)

            # Save checkpoint
            torch.save(checkpoint, save_path)
            if self.master_process:  # Only print from the master process in DDP
                print(f"Checkpoint saved: {save_path}")

        except Exception as e:
            print(f"Failed to save checkpoint: {e}")

    def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
        """Loads a model checkpoint to resume training.

        Restores the prior model and optimizer states from a saved checkpoint, handling
        DDP compatibility for state dictionaries.

        Parameters
        ----------
        `checkpoint_path` : str
            Path to the checkpoint file.

        Returns
        -------
        epoch : int
            The epoch at which the checkpoint was saved.
        loss : float
            The loss value at the checkpoint.
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # Load prior model
        if 'prior_model_state_dict' in checkpoint:
            state_dict = checkpoint['prior_model_state_dict']

            # Handle DDP state dict compatibility
            if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
                state_dict = {f'module.{k}': v for k, v in state_dict.items()}
            elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
                state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

            self.prior_model.load_state_dict(state_dict)

        # Load optimizer
        if 'optimizer_state_dict' in checkpoint:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            except Exception as e:
                warnings.warn(f"Failed to load optimizer state: {e}")

        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint.get('loss', float('inf'))

        if self.master_process:
            print(f"Loaded checkpoint from {checkpoint_path} (epoch {epoch}, loss {loss:.4f})")

        return epoch, loss
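
    # Resume sketch (hypothetical path, not part of the original file):
    #   trainer = TrainUnCLIPPrior(...)
    #   start_epoch, best_loss = trainer.load_checkpoint("prior/best_model.pth")
    # Note that only model and optimizer states are restored; the warmup and
    # plateau schedulers restart, since their states are not checkpointed.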


"""
from prior_diff import ForwardUnCLIP, ReverseUnCLIP, VarianceSchedulerUnCLIP
from prior_model import UnCLIPTransformerPrior
from clip_model import CLIPEncoder
from project_prior import Projection
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset
import torch
import torch.nn as nn


# Option 2A: Use CIFAR-10 with descriptive captions
class CIFAR10WithCaptions(Dataset):
    def __init__(self, cifar_dataset):
        self.dataset = cifar_dataset
        self.class_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck'
        ]
        # More descriptive templates
        self.templates = [
            "A photo of a {}",
            "An image of a {}",
            "A picture of a {}",
            "This is a {}",
        ]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        class_name = self.class_names[label]
        # Use different templates for variety
        template = self.templates[idx % len(self.templates)]
        caption = template.format(class_name)
        return image, caption


# Updated transforms for CLIP
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 with captions
cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataset = CIFAR10WithCaptions(cifar_train)
test_dataset = CIFAR10WithCaptions(cifar_test)

# Small subset for testing
train_subset_indices = torch.randperm(len(train_dataset))[:100]
test_subset_indices = torch.randperm(len(test_dataset))[:20]

train_subset = Subset(train_dataset, train_subset_indices)
test_subset = Subset(test_dataset, test_subset_indices)

# DataLoaders
t_loader = DataLoader(train_subset, batch_size=32, shuffle=True, pin_memory=True)
val = DataLoader(test_subset, batch_size=10, shuffle=False, pin_memory=True)

h_model = VarianceSchedulerUnCLIP(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    trainable_beta=True,
    beta_method="cosine"
)

c_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32")
tp = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=480,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
)
ip = Projection(
    input_dim=512,
    output_dim=320,
    hidden_dim=480,
    num_layers=2,
    dropout=0.1,
    use_layer_norm=True
)

d_model = ForwardUnCLIP(h_model)
r_model = ReverseUnCLIP(h_model)

p_model = UnCLIPTransformerPrior(
    forward_diffusion=d_model,
    reverse_diffusion=r_model,  # will be used during training
    text_projection=tp,         # used during training instead of PCA in the main paper
    image_projection=ip,
    embedding_dim=320,
    num_layers=12,
    num_attention_heads=8,
    feedforward_dim=512,
    max_sequence_length=2,
    dropout_rate=0.3
)


opt = torch.optim.AdamW([p for p in p_model.parameters() if p.requires_grad], lr=1e-3)

models = [h_model, p_model, tp, ip]

total_params = 0
for model in models:
    total_params += sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total_params)

obj = nn.MSELoss()


train = TrainUnCLIPPrior(
    prior_model=p_model,
    clip_model=c_model,
    train_loader=t_loader,
    optimizer=opt,
    objective=obj,
    val_loader=val,
    max_epochs=5,
    device="cuda",
    store_path="prior",
    patience=3,
    warmup_epochs=2,
    val_frequency=3,
    use_ddp=False,
    grad_accumulation_steps=2,
    log_frequency=1,
    use_compilation=False,
    embedding_output_range=(-1.0, 1.0),
    reduce_clip_embedding_dim=True,
    transformer_embedding_dim=320,
    normalize_clip_embeddings=True
)

train_losses, best_val_loss = train()
"""