TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
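A minimal quick-start sketch for the modules listed above (the submodule names come straight from the listing; the install command and the idea that these are the public entry points are assumptions, since the diff does not show the exports of torchdiff/__init__.py):

    # assumed install: pip install torchdiff==2.0.0
    from torchdiff import ddpm, ddim, ldm, sde, unclip, utils  # consolidated modules listed above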
unclip/train_decoder.py
ADDED
@@ -0,0 +1,1059 @@
import torch
import torch.nn as nn
from typing import Optional, List, Tuple, Union, Callable, Any
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm
import os
import warnings


class TrainUnClipDecoder(nn.Module):
    """Trainer for the UnCLIP decoder model.

    Orchestrates the training of the UnCLIP decoder model, integrating CLIP embeddings, forward
    and reverse diffusion processes, and optional dimensionality reduction. Supports mixed
    precision, gradient accumulation, DDP, and comprehensive evaluation metrics.

    Parameters
    ----------
    `clip_embedding_dim` : int
        Dimensionality of the input embeddings.
    `decoder_model` : nn.Module
        The UnCLIP decoder model (e.g., UnClipDecoder) to be trained.
    `clip_model` : nn.Module
        CLIP model for generating text and image embeddings.
    `train_loader` : torch.utils.data.DataLoader
        DataLoader for training data.
    `optimizer` : torch.optim.Optimizer
        Optimizer for training the decoder model.
    `objective` : Callable
        Loss function to compute the difference between predicted and target noise.
    `clip_text_projection` : nn.Module, optional
        Projection module for text embeddings, default None.
    `clip_image_projection` : nn.Module, optional
        Projection module for image embeddings, default None.
    `val_loader` : torch.utils.data.DataLoader, optional
        DataLoader for validation data, default None.
    `metrics_` : Any, optional
        Object providing evaluation metrics (e.g., FID, MSE, PSNR, SSIM, LPIPS), default None.
    `max_epochs` : int, optional
        Maximum number of training epochs (default: 1000).
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU).
    `store_path` : str, optional
        Directory to save model checkpoints (default: "unclip_decoder").
    `patience` : int, optional
        Number of epochs to wait for improvement before early stopping (default: 100).
    `warmup_epochs` : int, optional
        Number of epochs for learning rate warmup (default: 100).
    `val_frequency` : int, optional
        Frequency (in epochs) for validation (default: 10).
    `use_ddp` : bool, optional
        Whether to use Distributed Data Parallel training (default: False).
    `grad_accumulation_steps` : int, optional
        Number of gradient accumulation steps before optimizer update (default: 1).
    `log_frequency` : int, optional
        Frequency (in epochs) for printing progress (default: 1).
    `use_compilation` : bool, optional
        Whether to compile the model using torch.compile (default: False).
    `image_output_range` : Tuple[float, float], optional
        Range for clamping output images (default: (-1.0, 1.0)).
    `reduce_clip_embedding_dim` : bool, optional
        Whether to apply dimensionality reduction to embeddings (default: True).
    `transformer_embedding_dim` : int, optional
        Output dimensionality for reduced embeddings (default: 312).
    `normalize_clip_embeddings` : bool, optional
        Whether to normalize CLIP embeddings (default: True).
    `finetune_clip_projections` : bool, optional
        Whether to fine-tune projection layers (default: False).
    """
    def __init__(
            self,
            clip_embedding_dim: int,
            decoder_model: nn.Module,
            clip_model: nn.Module,
            train_loader: torch.utils.data.DataLoader,
            optimizer: torch.optim.Optimizer,
            objective: Callable,
            clip_text_projection: Optional[nn.Module] = None,
            clip_image_projection: Optional[nn.Module] = None,
            val_loader: Optional[torch.utils.data.DataLoader] = None,
            metrics_: Optional[Any] = None,
            max_epochs: int = 1000,
            device: Optional[Union[str, torch.device]] = None,
            store_path: str = "unclip_decoder",
            patience: int = 100,
            warmup_epochs: int = 100,
            val_frequency: int = 10,
            use_ddp: bool = False,
            grad_accumulation_steps: int = 1,
            log_frequency: int = 1,
            use_compilation: bool = False,
            image_output_range: Tuple[float, float] = (-1.0, 1.0),
            reduce_clip_embedding_dim: bool = True,
            transformer_embedding_dim: int = 312,
            normalize_clip_embeddings: bool = True,
            finetune_clip_projections: bool = False  # whether the text and image projection models should be fine-tuned
    ):
        super().__init__()
        # training configuration
        self.use_ddp = use_ddp
        self.grad_accumulation_steps = grad_accumulation_steps
        self.use_compilation = use_compilation
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # core models
        self.decoder_model = decoder_model.to(self.device)
        self.clip_model = clip_model.to(self.device)

        self.reduce_clip_embedding_dim = reduce_clip_embedding_dim

        # setup distributed training
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        # projection models (PCA equivalent in the paper); created before compilation and
        # DDP wrapping so _compile_models and _wrap_models_for_ddp can see them
        if self.reduce_clip_embedding_dim and clip_text_projection is not None and clip_image_projection is not None:
            self.clip_text_projection = clip_text_projection.to(self.device)
            self.clip_image_projection = clip_image_projection.to(self.device)
        else:
            self.clip_text_projection = None
            self.clip_image_projection = None

        # compile and wrap models
        self._compile_models()
        self._wrap_models_for_ddp()

        # training components
        self.clip_embedding_dim = transformer_embedding_dim if self.reduce_clip_embedding_dim else clip_embedding_dim
        self.metrics_ = metrics_
        self.optimizer = optimizer
        self.objective = objective
        self.train_loader = train_loader
        self.val_loader = val_loader

        # training parameters
        self.max_epochs = max_epochs
        self.patience = patience
        self.val_frequency = val_frequency
        self.log_frequency = log_frequency
        self.image_output_range = image_output_range
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.transformer_embedding_dim = transformer_embedding_dim
        self.finetune_clip_projections = finetune_clip_projections

        # checkpoint management
        self.store_path = store_path

        # learning rate scheduling
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    def forward(self) -> Tuple[List[float], float]:
        """Trains the UnCLIP decoder model to predict noise for denoising.

        Executes the training loop, optimizing the decoder model using CLIP embeddings, mixed
        precision, gradient clipping, and learning rate scheduling. Supports validation, early
        stopping, and checkpointing.

        Returns
        -------
        train_losses : List[float]
            List of mean training losses per epoch.
        best_val_loss : float
            Best validation or training loss achieved.
        """
        # set models to training mode
        self.decoder_model.train()  # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj to train mode
        if not self.decoder_model.forward_diffusion.variance_scheduler.trainable_beta:  # if beta is not trainable
            self.decoder_model.forward_diffusion.variance_scheduler.eval()

        # set text_projection and image_projection to train mode if fine-tuning
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            if self.finetune_clip_projections:
                self.clip_text_projection.train()
                self.clip_image_projection.train()
            else:
                self.clip_text_projection.eval()
                self.clip_image_projection.eval()

        # set CLIP model to eval mode (frozen)
        if self.clip_model is not None:
            self.clip_model.eval()

        # initialize training components
        scaler = torch.GradScaler()
        train_losses = []
        best_val_loss = float("inf")
        wait = 0

        # main training loop
        for epoch in range(self.max_epochs):
            # set epoch for distributed sampler if using DDP
            if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
                self.train_loader.sampler.set_epoch(epoch)

            train_losses_epoch = []

            # training step loop with gradient accumulation
            for step, (images, texts) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
                images = images.to(self.device, non_blocking=True)

                # forward pass with mixed precision
                with torch.autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
                    # encode text and image with CLIP
                    text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)

                    # reduce dimensionality (PCA equivalent)
                    text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                        text_embeddings, image_embeddings
                    )

                    # use decoder model to predict noise
                    p_classifier_free = torch.rand(1).item()
                    p_text_drop = torch.rand(1).item()
                    predicted_noise, noise = self.decoder_model(
                        image_embeddings,
                        text_embeddings,
                        images,
                        texts,
                        p_classifier_free,
                        p_text_drop
                    )

                    # compute loss
                    loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

                scaler.scale(loss).backward()

                if (step + 1) % self.grad_accumulation_steps == 0:
                    # clip gradients
                    scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), max_norm=1.0)  # covers all submodules
                    if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
                        torch.nn.utils.clip_grad_norm_(self.clip_text_projection.parameters(), max_norm=1.0)
                        torch.nn.utils.clip_grad_norm_(self.clip_image_projection.parameters(), max_norm=1.0)

                    scaler.step(self.optimizer)
                    scaler.update()
                    self.optimizer.zero_grad()
                    self.warmup_lr_scheduler.step()
                    torch.cuda.empty_cache()  # clear memory after optimizer step

                train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

            mean_train_loss = self._compute_mean_loss(train_losses_epoch)
            train_losses.append(mean_train_loss)

            if self.master_process and (epoch + 1) % self.log_frequency == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

            current_loss = mean_train_loss

            if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
                val_metrics = self.validate()
                val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

                if self.master_process:
                    print(f" | Val Loss: {val_loss:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                        print(f" | FID: {fid:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                        print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                    if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                        print(f" | LPIPS: {lpips_score:.4f}", end="")
                    print()

            self.scheduler.step(current_loss)

            if self.master_process:
                if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
                    best_val_loss = current_loss
                    wait = 0
                    self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
                else:
                    wait += 1
                    if wait >= self.patience:
                        print("Early stopping triggered")
                        self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                        break

        if self.use_ddp:
            destroy_process_group()

        return train_losses, best_val_loss

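    # Gradient-accumulation sketch (illustrative numbers, not package defaults): with
    # grad_accumulation_steps=4 and a DataLoader batch size of 8, the optimizer steps once
    # per 32 samples; forward() divides each loss by 4 before backward() and multiplies it
    # back by 4 when appending it to train_losses_epoch for logging.
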
    def _setup_ddp(self) -> None:
        """Sets up Distributed Data Parallel training configuration.

        Initializes the process group, sets up rank information, and configures the CUDA
        device for the current process in DDP mode.
        """
        required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
        for var in required_env_vars:
            if var not in os.environ:
                raise ValueError(f"DDP enabled but {var} environment variable not set")

        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA but CUDA is not available")

        if not torch.distributed.is_initialized():
            init_process_group(backend="nccl")

        self.ddp_rank = int(os.environ["RANK"])
        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
        self.ddp_world_size = int(os.environ["WORLD_SIZE"])

        self.device = torch.device(f"cuda:{self.ddp_local_rank}")
        torch.cuda.set_device(self.device)

        self.master_process = self.ddp_rank == 0

        if self.master_process:
            print(f"DDP initialized with world_size={self.ddp_world_size}")

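    # Launch sketch: _setup_ddp reads RANK, LOCAL_RANK and WORLD_SIZE from the environment,
    # which torchrun sets, e.g.
    #   torchrun --nproc_per_node=4 my_decoder_training_script.py
    # (the script name is a placeholder).
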
    def _setup_single_gpu(self) -> None:
        """Sets up single GPU or CPU training configuration.

        Configures the training setup for single-device operation, setting rank and process
        information for non-DDP training.
        """
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True

    @staticmethod
    def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
        """Creates a learning rate scheduler for warmup.

        Generates a scheduler that linearly increases the learning rate from 0 to the
        optimizer's initial value over the specified warmup epochs.

        Parameters
        ----------
        `optimizer` : torch.optim.Optimizer
            Optimizer to apply the scheduler to.
        `warmup_epochs` : int
            Number of epochs for the warmup phase.

        Returns
        -------
        lr_scheduler : torch.optim.lr_scheduler.LambdaLR
            Learning rate scheduler for warmup.
        """
        def lr_lambda(epoch):
            return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0

        return LambdaLR(optimizer, lr_lambda)

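    # Warmup sketch: with warmup_epochs=100 the lr_lambda multiplier ramps linearly
    # (0 -> 0.0, 50 -> 0.5, 100 and beyond -> 1.0). forward() calls
    # warmup_lr_scheduler.step() once per optimizer update, so the ramp is measured in
    # update steps rather than epochs.
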
    def _wrap_models_for_ddp(self) -> None:
        """Wraps models with DistributedDataParallel for multi-GPU training.

        Configures the decoder model and, if fine-tuning, the projection models for DDP training.
        """
        if self.use_ddp:
            self.decoder_model = self.decoder_model.to(self.ddp_local_rank)
            self.decoder_model = DDP(
                self.decoder_model,
                device_ids=[self.ddp_local_rank],
                find_unused_parameters=True
            )
            # only wrap text_projection and image_projection if they are trainable
            if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
                self.clip_text_projection = self.clip_text_projection.to(self.ddp_local_rank)
                self.clip_image_projection = self.clip_image_projection.to(self.ddp_local_rank)
                self.clip_text_projection = DDP(self.clip_text_projection, device_ids=[self.ddp_local_rank])
                self.clip_image_projection = DDP(self.clip_image_projection, device_ids=[self.ddp_local_rank])

    def _compile_models(self) -> None:
        """Compiles models for optimization if supported.

        Attempts to compile the decoder model and, if fine-tuning, the projection models using
        torch.compile for optimization, falling back to uncompiled execution if compilation fails.
        """
        if self.use_compilation:
            try:
                self.decoder_model = self.decoder_model.to(self.device)
                self.decoder_model = torch.compile(self.decoder_model, mode="reduce-overhead")
                # only compile text_projection and image_projection if they are trainable
                if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
                    self.clip_text_projection = self.clip_text_projection.to(self.device)
                    self.clip_image_projection = self.clip_image_projection.to(self.device)
                    self.clip_text_projection = torch.compile(self.clip_text_projection, mode="reduce-overhead")
                    self.clip_image_projection = torch.compile(self.clip_image_projection, mode="reduce-overhead")
                if self.master_process:
                    print("Models compiled successfully")
            except Exception as e:
                if self.master_process:
                    print(f"Model compilation failed: {e}. Continuing without compilation.")

    def _get_clip_embeddings(
            self,
            images: torch.Tensor,
            texts: Union[List, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes images and texts using the CLIP model.

        Generates text and image embeddings using the CLIP model, with optional normalization.

        Parameters
        ----------
        `images` : torch.Tensor
            Input images, shape (batch_size, channels, height, width).
        `texts` : Union[List, torch.Tensor]
            Text prompts for conditional generation.

        Returns
        -------
        text_embeddings : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        image_embeddings : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).
        """
        with torch.no_grad():
            # encode text y with CLIP text encoder: z_t ← CLIP_text(y)
            text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
            # encode image x with CLIP image encoder: z_i ← CLIP_image(x)
            image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
        return text_embeddings, image_embeddings

    def _apply_dimensionality_reduction(
            self,
            text_embeddings: torch.Tensor,
            image_embeddings: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Applies dimensionality reduction to embeddings if enabled.

        Projects text and image embeddings to a lower-dimensional space using learned
        projection layers, mimicking PCA as used in the UnCLIP paper.

        Parameters
        ----------
        `text_embeddings` : torch.Tensor
            CLIP text embeddings, shape (batch_size, embedding_dim).
        `image_embeddings` : torch.Tensor
            CLIP image embeddings, shape (batch_size, embedding_dim).

        Returns
        -------
        text_embeddings : torch.Tensor
            Projected text embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
        image_embeddings : torch.Tensor
            Projected image embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
        """
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            if not self.finetune_clip_projections:
                with torch.no_grad():
                    text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
                    image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
            else:
                text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
                image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
        return text_embeddings.to(self.device), image_embeddings.to(self.device)

    def _compute_mean_loss(self, losses: List[float]) -> float:
        """Computes mean loss with DDP synchronization if needed.

        Calculates the mean of the provided losses and synchronizes the result across
        processes in DDP mode.

        Parameters
        ----------
        `losses` : List[float]
            List of loss values for the current epoch.

        Returns
        -------
        mean_loss : float
            Mean loss value, synchronized if using DDP.
        """
        if not losses:
            return 0.0
        mean_loss = sum(losses) / len(losses)
        if self.use_ddp:
            # synchronize loss across all processes
            loss_tensor = torch.tensor(mean_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            mean_loss = (loss_tensor / self.ddp_world_size).item()

        return mean_loss

    def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
        """Saves model checkpoint.

        Saves the state of the decoder model, its submodules, optimizer, and schedulers,
        with options for best model and epoch-specific checkpoints.

        Parameters
        ----------
        `epoch` : int
            Current epoch number.
        `loss` : float
            Current loss value.
        `is_best` : bool, optional
            Whether to save as the best model checkpoint (default: False).
        `suffix` : str, optional
            Suffix to add to checkpoint filename, default "".
        """
        if not self.master_process:
            return
        checkpoint = {
            'epoch': epoch,
            'loss': loss,
            # Core models (submodules of decoder_model)
            'noise_predictor_state_dict': self.decoder_model.module.noise_predictor.state_dict() if self.use_ddp else self.decoder_model.noise_predictor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Training configuration
            'embedding_dim': self.clip_embedding_dim,
            'output_dim': self.transformer_embedding_dim,
            'reduce_dim': self.reduce_clip_embedding_dim,
            'normalize': self.normalize_clip_embeddings
        }

        # Save conditional model (submodule of decoder_model)
        if self.decoder_model.conditional_model is not None:
            checkpoint['conditional_model_state_dict'] = (
                self.decoder_model.module.conditional_model.state_dict() if self.use_ddp
                else self.decoder_model.conditional_model.state_dict()
            )

        # Save variance scheduler (submodule of decoder_model, always saved)
        checkpoint['variance_scheduler_state_dict'] = (
            self.decoder_model.module.forward_diffusion.variance_scheduler.state_dict() if self.use_ddp
            else self.decoder_model.forward_diffusion.variance_scheduler.state_dict()
        )

        # Save CLIP time projection layer (submodule of decoder_model)
        checkpoint['clip_time_proj_state_dict'] = (
            self.decoder_model.module.clip_time_proj.state_dict() if self.use_ddp
            else self.decoder_model.clip_time_proj.state_dict()
        )

        # Save decoder projection layer (submodule of decoder_model)
        checkpoint['decoder_projection_state_dict'] = (
            self.decoder_model.module.decoder_projection.state_dict() if self.use_ddp
            else self.decoder_model.decoder_projection.state_dict()
        )

        # Save projection models (PCA equivalent); they are only DDP-wrapped when fine-tuned
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            checkpoint['text_projection_state_dict'] = (
                self.clip_text_projection.module.state_dict() if self.use_ddp and self.finetune_clip_projections
                else self.clip_text_projection.state_dict()
            )
            checkpoint['image_projection_state_dict'] = (
                self.clip_image_projection.module.state_dict() if self.use_ddp and self.finetune_clip_projections
                else self.clip_image_projection.state_dict()
            )

        # Save schedulers state
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()

        filename = f"unclip_decoder_epoch_{epoch}{suffix}.pth"
        if is_best:
            filename = f"unclip_decoder_best{suffix}.pth"

        filepath = os.path.join(self.store_path, filename)
        os.makedirs(self.store_path, exist_ok=True)
        torch.save(checkpoint, filepath)

        if is_best:
            print(f"Best model saved: {filepath}")

    def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
        """Loads model checkpoint.

        Restores the state of the decoder model, its submodules, optimizer, and schedulers
        from a saved checkpoint, handling DDP compatibility.

        Parameters
        ----------
        `checkpoint_path` : str
            Path to the checkpoint file.

        Returns
        -------
        epoch : int
            The epoch at which the checkpoint was saved.
        loss : float
            The loss at the checkpoint.
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
            """Helper function to load state dict with DDP compatibility."""
            try:
                # Handle DDP state dict compatibility
                if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
                    state_dict = {f'module.{k}': v for k, v in state_dict.items()}
                elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
                    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

                model.load_state_dict(state_dict)
                if self.master_process:
                    print(f"✓ Loaded {model_name}")
            except Exception as e:
                warnings.warn(f"Failed to load {model_name}: {e}")

        # Load core noise predictor model (submodule of decoder_model)
        if 'noise_predictor_state_dict' in checkpoint:
            _load_model_state_dict(self.decoder_model.noise_predictor, checkpoint['noise_predictor_state_dict'],
                                   'noise_predictor')

        # Load conditional model (submodule of decoder_model)
        if self.decoder_model.conditional_model is not None and 'conditional_model_state_dict' in checkpoint:
            _load_model_state_dict(self.decoder_model.conditional_model, checkpoint['conditional_model_state_dict'],
                                   'conditional_model')

        # Load variance scheduler (submodule of decoder_model)
        if 'variance_scheduler_state_dict' in checkpoint:
            state_dict = checkpoint.get('variance_scheduler_state_dict')
            try:
                _load_model_state_dict(self.decoder_model.forward_diffusion.variance_scheduler, state_dict, 'variance_scheduler')
            except Exception as e:
                warnings.warn(f"Failed to load variance scheduler: {e}")

        # Load CLIP time projection layer (submodule of decoder_model)
        if 'clip_time_proj_state_dict' in checkpoint:
            try:
                _load_model_state_dict(self.decoder_model.clip_time_proj, checkpoint['clip_time_proj_state_dict'],
                                       'clip_time_proj')
            except Exception as e:
                warnings.warn(f"Failed to load CLIP time projection: {e}")

        # Load decoder projection layer (submodule of decoder_model)
        if 'decoder_projection_state_dict' in checkpoint:
            try:
                _load_model_state_dict(self.decoder_model.decoder_projection,
                                       checkpoint['decoder_projection_state_dict'], 'decoder_projection')
            except Exception as e:
                warnings.warn(f"Failed to load decoder projection: {e}")

        # Load projection models (PCA equivalent)
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            if 'text_projection_state_dict' in checkpoint:
                _load_model_state_dict(self.clip_text_projection, checkpoint['text_projection_state_dict'],
                                       'text_projection')
            if 'image_projection_state_dict' in checkpoint:
                _load_model_state_dict(self.clip_image_projection, checkpoint['image_projection_state_dict'],
                                       'image_projection')

        # Load optimizer
        if 'optimizer_state_dict' in checkpoint:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if self.master_process:
                    print("✓ Loaded optimizer")
            except Exception as e:
                warnings.warn(f"Failed to load optimizer state: {e}")

        # Load schedulers
        if 'scheduler_state_dict' in checkpoint:
            try:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded main scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load scheduler state: {e}")

        if 'warmup_scheduler_state_dict' in checkpoint:
            try:
                self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded warmup scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load warmup scheduler state: {e}")

        # Verify configuration compatibility
        if 'embedding_dim' in checkpoint:
            if checkpoint['embedding_dim'] != self.clip_embedding_dim:
                warnings.warn(
                    f"Embedding dimension mismatch: checkpoint={checkpoint['embedding_dim']}, current={self.clip_embedding_dim}")

        if 'reduce_dim' in checkpoint:
            if checkpoint['reduce_dim'] != self.reduce_clip_embedding_dim:
                warnings.warn(
                    f"Reduce dimension setting mismatch: checkpoint={checkpoint['reduce_dim']}, current={self.reduce_clip_embedding_dim}")

        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint.get('loss', float('inf'))

        if self.master_process:
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
            print(f"Epoch: {epoch}, Loss: {loss:.4f}")

        return epoch, loss

    def validate(self) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Validates the UnCLIP decoder model.

        Computes validation loss and optional metrics (FID, MSE, PSNR, SSIM, LPIPS) by
        encoding images and texts, applying forward diffusion, predicting noise, and
        reconstructing images through reverse diffusion.

        Returns
        -------
        val_loss : float
            Mean validation loss.
        fid_avg : float or None
            Average FID score, if computed.
        mse_avg : float or None
            Average MSE score, if computed.
        psnr_avg : float or None
            Average PSNR score, if computed.
        ssim_avg : float or None
            Average SSIM score, if computed.
        lpips_avg : float or None
            Average LPIPS score, if computed.
        """
        # set models to eval mode for evaluation
        self.decoder_model.eval()  # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj, decoder_projection to eval mode
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            self.clip_text_projection.eval()
            self.clip_image_projection.eval()
        if self.clip_model is not None:
            self.clip_model.eval()

        val_losses = []
        fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

        with torch.no_grad():
            for images, texts in self.val_loader:
                images = images.to(self.device, non_blocking=True)
                images_orig = images.clone()
                text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
                text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
                    text_embeddings, image_embeddings
                )
                p_classifier_free = torch.rand(1).item()
                p_text_drop = torch.rand(1).item()
                predicted_noise, noise = self.decoder_model(
                    image_embeddings,
                    text_embeddings,
                    images,
                    texts,
                    p_classifier_free,
                    p_text_drop
                )
                loss = self.objective(predicted_noise, noise)
                val_losses.append(loss.item())

                if self.metrics_ is not None and self.decoder_model.reverse_diffusion is not None:
                    xt = torch.randn_like(images).to(self.device)
                    for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
                        time_steps = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
                        prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device, dtype=torch.long)
                        image_embeddings = self.decoder_model._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
                        text_embeddings = self.decoder_model._apply_text_dropout(text_embeddings, p_text_drop)
                        c = self.decoder_model.decoder_projection(image_embeddings)  # updated to submodule
                        y_encoded = self.decoder_model._encode_text_with_glide(texts if text_embeddings is not None else None)
                        context = self.decoder_model._concatenate_embeddings(y_encoded, c)
                        clip_image_embedding = self.decoder_model.clip_time_proj(image_embeddings)
                        predicted_noise = self.decoder_model.noise_predictor(xt, time_steps, context, clip_image_embedding)
                        xt, _ = self.decoder_model.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)

                    x_hat = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])

                    if self.normalize_clip_embeddings:
                        x_hat = (x_hat - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
                        x_orig = (images_orig - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
                    else:
                        x_orig = images_orig

                    metrics_result = self.metrics_.forward(x_orig, x_hat)
                    fid = metrics_result[0] if getattr(self.metrics_, 'fid', False) else float('inf')
                    mse = metrics_result[1] if getattr(self.metrics_, 'metrics', False) else None
                    psnr = metrics_result[2] if getattr(self.metrics_, 'metrics', False) else None
                    ssim = metrics_result[3] if getattr(self.metrics_, 'metrics', False) else None
                    lpips_score = metrics_result[4] if getattr(self.metrics_, 'lpips', False) else None

                    if fid != float('inf'):
                        fid_scores.append(fid)
                    if mse is not None:
                        mse_scores.append(mse)
                    if psnr is not None:
                        psnr_scores.append(psnr)
                    if ssim is not None:
                        ssim_scores.append(ssim)
                    if lpips_score is not None:
                        lpips_scores.append(lpips_score)

        # compute averages
        val_loss = torch.tensor(val_losses).mean().item()
        fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
        mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
        psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
        ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
        lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None

        # synchronize metrics across GPUs in DDP mode
        if self.use_ddp:
            metrics = [val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg]
            metrics_tensors = [torch.tensor(m, device=self.device) if m is not None else torch.tensor(float('inf'), device=self.device) for m in metrics]
            for tensor in metrics_tensors:
                dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
            val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg = [t.item() if t.item() != float('inf') else (None if i > 1 else float('inf')) for i, t in enumerate(metrics_tensors)]

        # return to training mode
        self.decoder_model.train()  # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj, decoder_projection to train mode
        if not self.decoder_model.forward_diffusion.variance_scheduler.trainable_beta:
            self.decoder_model.forward_diffusion.variance_scheduler.eval()
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            if self.finetune_clip_projections:
                self.clip_text_projection.train()
                self.clip_image_projection.train()
            else:
                self.clip_text_projection.eval()
                self.clip_image_projection.eval()
        if self.clip_model is not None:
            self.clip_model.eval()

        return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg

"""
|
|
838
|
+
from utils import NoisePredictor, TextEncoder, Metrics
|
|
839
|
+
from clip_model import CLIPEncoder
|
|
840
|
+
from torchvision import datasets, transforms
|
|
841
|
+
from torch.utils.data import DataLoader, Subset, Dataset
|
|
842
|
+
from project_prior import Projection
|
|
843
|
+
import torch
|
|
844
|
+
from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP
|
|
845
|
+
from decoder_model import UnClipDecoder
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
class CIFAR10WithCaptions(Dataset):
|
|
849
|
+
def __init__(self, cifar_dataset):
|
|
850
|
+
self.dataset = cifar_dataset
|
|
851
|
+
self.class_names = [
|
|
852
|
+
'airplane', 'automobile', 'bird', 'cat', 'deer',
|
|
853
|
+
'dog', 'frog', 'horse', 'ship', 'truck'
|
|
854
|
+
]
|
|
855
|
+
# More descriptive templates
|
|
856
|
+
self.templates = [
|
|
857
|
+
"A photo of a {}",
|
|
858
|
+
"An image of a {}",
|
|
859
|
+
"A picture of a {}",
|
|
860
|
+
"This is a {}",
|
|
861
|
+
]
|
|
862
|
+
|
|
863
|
+
def __len__(self):
|
|
864
|
+
return len(self.dataset)
|
|
865
|
+
|
|
866
|
+
def __getitem__(self, idx):
|
|
867
|
+
image, label = self.dataset[idx]
|
|
868
|
+
class_name = self.class_names[label]
|
|
869
|
+
# Use different templates for variety
|
|
870
|
+
template = self.templates[idx % len(self.templates)]
|
|
871
|
+
caption = template.format(class_name)
|
|
872
|
+
return image, caption
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
# Updated transforms for CLIP
|
|
877
|
+
transform = transforms.Compose([
|
|
878
|
+
transforms.Resize((224, 224)),
|
|
879
|
+
transforms.ToTensor(),
|
|
880
|
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
881
|
+
])
|
|
882
|
+
|
|
883
|
+
# Load CIFAR-10 with captions
|
|
884
|
+
cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
|
|
885
|
+
cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
|
|
886
|
+
|
|
887
|
+
train_dataset = CIFAR10WithCaptions(cifar_train)
|
|
888
|
+
test_dataset = CIFAR10WithCaptions(cifar_test)
|
|
889
|
+
|
|
890
|
+
# Small subset for testing
|
|
891
|
+
train_subset_indices = torch.randperm(len(train_dataset))[:4]
|
|
892
|
+
test_subset_indices = torch.randperm(len(test_dataset))[:2]
|
|
893
|
+
train_subset = Subset(train_dataset, train_subset_indices)
|
|
894
|
+
test_subset = Subset(test_dataset, test_subset_indices)
|
|
895
|
+
|
|
896
|
+
# DataLoaders
|
|
897
|
+
t_loader = DataLoader(train_subset, batch_size=2, shuffle=True, pin_memory=True)
|
|
898
|
+
v_loader = DataLoader(test_subset, batch_size=1, shuffle=False, pin_memory=True)
|
|
899
|
+
|
|
900
|
+
d = torch.device("cuda")
|
|
901
|
+
|
|
902
|
+
n_model = NoisePredictor(
|
|
903
|
+
in_channels=3,
|
|
904
|
+
down_channels=[16, 32],
|
|
905
|
+
mid_channels=[32, 32],
|
|
906
|
+
up_channels=[32, 16],
|
|
907
|
+
down_sampling=[True, True],
|
|
908
|
+
time_embed_dim=32,
|
|
909
|
+
y_embed_dim=32,
|
|
910
|
+
num_down_blocks=2,
|
|
911
|
+
num_mid_blocks=2,
|
|
912
|
+
num_up_blocks=2,
|
|
913
|
+
down_sampling_factor=2
|
|
914
|
+
)
|
|
915
|
+
|
|
916
|
+
|
|
917
|
+
c_model = CLIPEncoder(
|
|
918
|
+
model_name="openai/clip-vit-base-patch32",
|
|
919
|
+
device="cuda",
|
|
920
|
+
use_fast=False
|
|
921
|
+
)
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
t_proj = Projection(
|
|
925
|
+
input_dim=512,
|
|
926
|
+
output_dim=32,
|
|
927
|
+
hidden_dim=128,
|
|
928
|
+
num_layers=2,
|
|
929
|
+
dropout=0.1,
|
|
930
|
+
use_layer_norm=True
|
|
931
|
+
)
|
|
932
|
+
i_proj = Projection(
|
|
933
|
+
input_dim=512,
|
|
934
|
+
output_dim=32,
|
|
935
|
+
hidden_dim=128,
|
|
936
|
+
num_layers=2,
|
|
937
|
+
dropout=0.1,
|
|
938
|
+
use_layer_norm=True
|
|
939
|
+
)
|
|
940
|
+
|
|
941
|
+
h_model = VarianceSchedulerUnCLIP(
|
|
942
|
+
num_steps=500,
|
|
943
|
+
beta_start=1e-4,
|
|
944
|
+
beta_end=0.02,
|
|
945
|
+
trainable_beta=False,
|
|
946
|
+
beta_method="linear"
|
|
947
|
+
)
|
|
948
|
+
for_ = ForwardUnCLIP(h_model)
|
|
949
|
+
rev_ = ReverseUnCLIP(h_model)
|
|
950
|
+
|
|
951
|
+
cond = TextEncoder(
|
|
952
|
+
use_pretrained_model=True,
|
|
953
|
+
model_name="bert-base-uncased",
|
|
954
|
+
vocabulary_size=30522,
|
|
955
|
+
num_layers=2,
|
|
956
|
+
input_dimension=32,
|
|
957
|
+
output_dimension=32,
|
|
958
|
+
num_heads=2,
|
|
959
|
+
context_length=77
|
|
960
|
+
).to(d)
|
|
961
|
+
|
|
962
|
+
decoder = UnClipDecoder(
|
|
963
|
+
embedding_dim=32,
|
|
964
|
+
noise_predictor=n_model,
|
|
965
|
+
forward_diffusion=for_,
|
|
966
|
+
reverse_diffusion=rev_,
|
|
967
|
+
conditional_model=cond, # GLIDE text encoder
|
|
968
|
+
tokenizer=None,
|
|
969
|
+
device="cpu",
|
|
970
|
+
output_range=(-1.0, 1.0),
|
|
971
|
+
normalize=True,
|
|
972
|
+
classifier_free=0.1, # paper specifies 10%
|
|
973
|
+
drop_caption=0.5, # paper specifies 50%
|
|
974
|
+
max_length=77
|
|
975
|
+
)
|
|
976
|
+
|
|
977
|
+
opt = torch.optim.AdamW([p for p in decoder.parameters() if p.requires_grad], lr=1e-3)
|
|
978
|
+
|
|
979
|
+
|
|
980
|
+
obj = nn.MSELoss()
|
|
981
|
+
|
|
982
|
+
mets = Metrics(
|
|
983
|
+
device="cpu",
|
|
984
|
+
fid=True,
|
|
985
|
+
metrics=True,
|
|
986
|
+
lpips_=True
|
|
987
|
+
)
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
model = TrainUnClipDecoder(
|
|
991
|
+
embedding_dim=512,
|
|
992
|
+
decoder_model=decoder,
|
|
993
|
+
clip_model=c_model,
|
|
994
|
+
train_loader=t_loader,
|
|
995
|
+
optimizer=opt,
|
|
996
|
+
objective=obj,
|
|
997
|
+
text_projection=t_proj,
|
|
998
|
+
image_projection=i_proj,
|
|
999
|
+
val_loader=v_loader,
|
|
1000
|
+
metrics_=mets,
|
|
1001
|
+
max_epoch=5,
|
|
1002
|
+
device="cuda",
|
|
1003
|
+
store_path="unclip_decoder",
|
|
1004
|
+
patience=5,
|
|
1005
|
+
warmup_epochs=2,
|
|
1006
|
+
val_frequency=10,
|
|
1007
|
+
use_ddp=False,
|
|
1008
|
+
num_grad_accumulation=1,
|
|
1009
|
+
progress_frequency=1,
|
|
1010
|
+
compilation=False,
|
|
1011
|
+
output_range=(-1.0, 1.0),
|
|
1012
|
+
reduce_dim=True,
|
|
1013
|
+
output_dim=32,
|
|
1014
|
+
normalize=True,
|
|
1015
|
+
finetune_projections=False
|
|
1016
|
+
)
|
|
1017
|
+
|
|
1018
|
+
# Ensure requires_grad is set correctly
|
|
1019
|
+
for p in model.clip_model.parameters():
|
|
1020
|
+
p.requires_grad = False
|
|
1021
|
+
if not model.finetune_projections:
|
|
1022
|
+
for p in model.text_projection.parameters():
|
|
1023
|
+
p.requires_grad = False
|
|
1024
|
+
for p in model.image_projection.parameters():
|
|
1025
|
+
p.requires_grad = False
|
|
1026
|
+
if not model.decoder_model.forward_diffusion.variance_scheduler.trainable_beta:
|
|
1027
|
+
for p in model.decoder_model.forward_diffusion.variance_scheduler.parameters():
|
|
1028
|
+
p.requires_grad = False
|
|
1029
|
+
|
|
1030
|
+
# Run training
|
|
1031
|
+
one, two = model()
|
|
1032
|
+
|
|
1033
|
+
# Count trainable parameters
|
|
1034
|
+
def count_trainable_parameters(model, finetune_projections=False):
|
|
1035
|
+
total_params = 0
|
|
1036
|
+
total_params += sum(p.numel() for p in model.decoder_model.parameters() if p.requires_grad)
|
|
1037
|
+
if finetune_projections and model.text_projection is not None and model.image_projection is not None:
|
|
1038
|
+
total_params += sum(p.numel() for p in model.text_projection.parameters() if p.requires_grad)
|
|
1039
|
+
total_params += sum(p.numel() for p in model.image_projection.parameters() if p.requires_grad)
|
|
1040
|
+
return total_params
|
|
1041
|
+
|
|
1042
|
+
# Case 1: finetune_projections=False, train_projection=False
|
|
1043
|
+
print("Trainable parameters (finetune_projections=False, train_projection=False):")
|
|
1044
|
+
total_params_false = count_trainable_parameters(model, finetune_projections=False)
|
|
1045
|
+
print(f"Total trainable parameters: {total_params_false}")
|
|
1046
|
+
|
|
1047
|
+
# Case 2: finetune_projections=True, train_projection=True
|
|
1048
|
+
model.finetune_projections = True
|
|
1049
|
+
for p in model.text_projection.parameters():
|
|
1050
|
+
p.requires_grad = True
|
|
1051
|
+
for p in model.image_projection.parameters():
|
|
1052
|
+
p.requires_grad = True
|
|
1053
|
+
|
|
1054
|
+
print("\nTrainable parameters (finetune_projections=True, train_projection=True):")
|
|
1055
|
+
total_params_true = count_trainable_parameters(model, finetune_projections=True)
|
|
1056
|
+
print(f"Total trainable parameters: {total_params_true}")
|
|
1057
|
+
|
|
1058
|
+
print("After parameters count")
|
|
1059
|
+
"""
|
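
# Resumption sketch (assumes a checkpoint written by _save_checkpoint with the default
# store_path; construct the trainer exactly as in the example above, then restore its state
# before calling it again):
#   start_epoch, best_loss = model.load_checkpoint("unclip_decoder/unclip_decoder_best.pth")
#   train_losses, best_val_loss = model()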