TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,1059 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List, Tuple, Union, Callable, Any
4
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
5
+ import torch.distributed as dist
6
+ from torch.nn.parallel import DistributedDataParallel as DDP
7
+ from torch.distributed import init_process_group, destroy_process_group
8
+ from tqdm import tqdm
9
+ import os
10
+ import warnings
11
+
12
+
13
+
14
+
15
+ class TrainUnClipDecoder(nn.Module):
16
+ """Trainer for the UnCLIP decoder model.
17
+
18
+ Orchestrates the training of the UnCLIP decoder model, integrating CLIP embeddings, forward
19
+ and reverse diffusion processes, and optional dimensionality reduction. Supports mixed
20
+ precision, gradient accumulation, DDP, and comprehensive evaluation metrics.
21
+
22
+ Parameters
23
+ ----------
24
+ `clip_embedding_dim` : int
25
+ Dimensionality of the input embeddings.
26
+ `decoder_model` : nn.Module
27
+ The UnCLIP decoder model (e.g., UnClipDecoder) to be trained.
28
+ `clip_model` : nn.Module
29
+ CLIP model for generating text and image embeddings.
30
+ `train_loader` : torch.utils.data.DataLoader
31
+ DataLoader for training data.
32
+ `optimizer` : torch.optim.Optimizer
33
+ Optimizer for training the decoder model.
34
+ `objective` : Callable
35
+ Loss function to compute the difference between predicted and target noise.
36
+ `clip_text_projection` : nn.Module, optional
37
+ Projection module for text embeddings, default None.
38
+ `clip_image_projection` : nn.Module, optional
39
+ Projection module for image embeddings, default None.
40
+ `val_loader` : torch.utils.data.DataLoader, optional
41
+ DataLoader for validation data, default None.
42
+ `metrics_` : Any, optional
43
+ Object providing evaluation metrics (e.g., FID, MSE, PSNR, SSIM, LPIPS), default None.
44
+ `max_epochs` : int, optional
45
+ Maximum number of training epochs (default: 1000).
46
+ `device` : Union[str, torch.device], optional
47
+ Device for computation (default: CUDA if available, else CPU).
48
+ `store_path` : str, optional
49
+ Directory to save model checkpoints (default: "unclip_decoder").
50
+ `patience` : int, optional
51
+ Number of epochs to wait for improvement before early stopping (default: 100).
52
+ `warmup_epochs` : int, optional
53
+ Number of epochs for learning rate warmup (default: 100).
54
+ `val_frequency` : int, optional
55
+ Frequency (in epochs) for validation (default: 10).
56
+ `use_ddp` : bool, optional
57
+ Whether to use Distributed Data Parallel training (default: False).
58
+ `grad_accumulation_steps` : int, optional
59
+ Number of gradient accumulation steps before optimizer update (default: 1).
60
+ `log_frequency` : int, optional
61
+ Frequency (in epochs) for printing progress (default: 1).
62
+ `use_compilation` : bool, optional
63
+ Whether to compile the model using torch.compile (default: False).
64
+ `image_output_range` : Tuple[float, float], optional
65
+ Range for clamping output images (default: (-1.0, 1.0)).
66
+ `reduce_clip_embedding_dim` : bool, optional
67
+ Whether to apply dimensionality reduction to embeddings (default: True).
68
+ `transformer_embedding_dim` : int, optional
69
+ Output dimensionality for reduced embeddings (default: 312).
70
+ `normalize_clip_embeddings` : bool, optional
71
+ Whether to normalize CLIP embeddings (default: True).
72
+ `finetune_clip_projections` : bool, optional
73
+ Whether to fine-tune projection layers (default: False).
74
+ """
75
    def __init__(
        self,
        clip_embedding_dim: int,
        decoder_model: nn.Module,
        clip_model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        clip_text_projection: Optional[nn.Module] = None,
        clip_image_projection: Optional[nn.Module] = None,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        metrics_: Optional[Any] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        store_path: str = "unclip_decoder",
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        reduce_clip_embedding_dim: bool = True,
        transformer_embedding_dim: int = 312,
        normalize_clip_embeddings: bool = True,
        finetune_clip_projections: bool = False  # whether the text/image projection layers should also be fine-tuned
    ) -> None:
        """Initialize trainer state: device/DDP setup, models, projections, and LR schedulers."""
        super().__init__()
        # training configuration
        self.use_ddp = use_ddp
        self.grad_accumulation_steps = grad_accumulation_steps
        self.use_compilation = use_compilation
        # resolve the device: explicit arg > CUDA if available > CPU
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # core models
        # NOTE(review): models are moved here first; _setup_ddp() below reassigns
        # self.device to this rank's GPU, and _wrap_models_for_ddp() re-moves the
        # decoder — confirm the double move is intended in DDP mode.
        self.decoder_model = decoder_model.to(self.device)
        self.clip_model = clip_model.to(self.device)

        self.reduce_clip_embedding_dim = reduce_clip_embedding_dim

        # setup distributed training (sets rank info and master_process flag)
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        # compile and wrap models (compilation first, then DDP wrapping)
        self._compile_models()
        self._wrap_models_for_ddp()

        # projection models (PCA equivalent in the paper); both must be provided
        # for dimensionality reduction to be active, otherwise they are disabled
        if self.reduce_clip_embedding_dim and clip_text_projection is not None and clip_image_projection is not None:
            self.clip_text_projection = clip_text_projection.to(self.device)
            self.clip_image_projection = clip_image_projection.to(self.device)
        else:
            self.clip_text_projection = None
            self.clip_image_projection = None

        # training components
        # effective embedding size seen downstream: reduced dim when projection is on
        self.clip_embedding_dim = transformer_embedding_dim if self.reduce_clip_embedding_dim else clip_embedding_dim
        self.metrics_ = metrics_
        self.optimizer = optimizer
        self.objective = objective
        self.train_loader = train_loader
        self.val_loader = val_loader

        # training parameters
        self.max_epochs = max_epochs
        self.patience = patience
        self.val_frequency = val_frequency
        self.log_frequency = log_frequency
        self.image_output_range = image_output_range
        # NOTE(review): reduce_clip_embedding_dim was already assigned above —
        # this second assignment is redundant but harmless.
        self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
        self.normalize_clip_embeddings = normalize_clip_embeddings
        self.transformer_embedding_dim = transformer_embedding_dim
        self.finetune_clip_projections = finetune_clip_projections


        # checkpoint management
        self.store_path = store_path

        # learning rate scheduling: plateau-based decay plus linear warmup
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
169
+
170
+ def forward(self) -> Tuple[List[float], float]:
171
+ """Trains the UnCLIP decoder model to predict noise for denoising.
172
+
173
+ Executes the training loop, optimizing the decoder model using CLIP embeddings, mixed
174
+ precision, gradient clipping, and learning rate scheduling. Supports validation, early
175
+ stopping, and checkpointing.
176
+
177
+ Returns
178
+ -------
179
+ train_losses : List[float]
180
+ List of mean training losses per epoch.
181
+ best_val_loss : float
182
+ Best validation or training loss achieved.
183
+ """
184
+ # set models to training mode
185
+ self.decoder_model.train() # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj to train mode
186
+ if not self.decoder_model.forward_diffusion.variance_scheduler.trainable_beta: # ff beta is not trainable
187
+ self.decoder_model.forward_diffusion.variance_scheduler.eval()
188
+
189
+ # set text_projection and image_projection to train mode if fine-tuning
190
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
191
+ if self.finetune_clip_projections:
192
+ self.clip_text_projection.train()
193
+ self.clip_image_projection.train()
194
+ else:
195
+ self.clip_text_projection.eval()
196
+ self.clip_image_projection.eval()
197
+
198
+ # set CLIP model to eval mode (frozen)
199
+ if self.clip_model is not None:
200
+ self.clip_model.eval()
201
+
202
+ # initialize training components
203
+ scaler = torch.GradScaler()
204
+ train_losses = []
205
+ best_val_loss = float("inf")
206
+ wait = 0
207
+
208
+ # main training loop
209
+ for epoch in range(self.max_epochs):
210
+ # set epoch for distributed sampler if using DDP
211
+ if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
212
+ self.train_loader.sampler.set_epoch(epoch)
213
+
214
+ train_losses_epoch = []
215
+
216
+ # training step loop with gradient accumulation
217
+ for step, (images, texts) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
218
+ images = images.to(self.device, non_blocking=True)
219
+
220
+ # forward pass with mixed precision
221
+ with torch.autocast(device_type='cuda' if self.device.type == 'cuda' else 'cpu'):
222
+ # encode text and image with CLIP
223
+ text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
224
+
225
+ # reduce dimensionality (PCA equivalent)
226
+ text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
227
+ text_embeddings, image_embeddings
228
+ )
229
+
230
+ # use decoder model to predict noise
231
+ p_classifier_free = torch.rand(1).item()
232
+ p_text_drop = torch.rand(1).item()
233
+ predicted_noise, noise = self.decoder_model(
234
+ image_embeddings,
235
+ text_embeddings,
236
+ images,
237
+ texts,
238
+ p_classifier_free,
239
+ p_text_drop
240
+ )
241
+
242
+ # compute loss
243
+ loss = self.objective(predicted_noise, noise) / self.num_grad_accumulation
244
+
245
+ scaler.scale(loss).backward()
246
+
247
+ if (step + 1) % self.num_grad_accumulation == 0:
248
+ # clip gradients
249
+ scaler.unscale_(self.optimizer)
250
+ torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), max_norm=1.0) # covers all submodules
251
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
252
+ torch.nn.utils.clip_grad_norm_(self.clip_text_projection.parameters(), max_norm=1.0)
253
+ torch.nn.utils.clip_grad_norm_(self.clip_image_projection.parameters(), max_norm=1.0)
254
+
255
+ scaler.step(self.optimizer)
256
+ scaler.update()
257
+ self.optimizer.zero_grad()
258
+ self.warmup_lr_scheduler.step()
259
+ torch.cuda.empty_cache() # clear memory after optimizer step
260
+
261
+ train_losses_epoch.append(loss.item() * self.num_grad_accumulation)
262
+
263
+ mean_train_loss = self._compute_mean_loss(train_losses_epoch)
264
+ train_losses.append(mean_train_loss)
265
+
266
+ if self.master_process and (epoch + 1) % self.log_frequency == 0:
267
+ current_lr = self.optimizer.param_groups[0]['lr']
268
+ print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")
269
+
270
+ current_loss = mean_train_loss
271
+
272
+ if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
273
+ val_metrics = self.validate()
274
+ val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics
275
+
276
+ if self.master_process:
277
+ print(f" | Val Loss: {val_loss:.4f}", end="")
278
+ if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
279
+ print(f" | FID: {fid:.4f}", end="")
280
+ if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
281
+ print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
282
+ if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
283
+ print(f" | LPIPS: {lpips_score:.4f}", end="")
284
+ print()
285
+
286
+ self.scheduler.step(current_loss)
287
+
288
+ if self.master_process:
289
+ if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
290
+ best_val_loss = current_loss
291
+ wait = 0
292
+ self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
293
+ else:
294
+ wait += 1
295
+ if wait >= self.patience:
296
+ print("Early stopping triggered")
297
+ self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
298
+ break
299
+
300
+ if self.use_ddp:
301
+ destroy_process_group()
302
+
303
+ return train_losses, best_val_loss
304
+
305
+ def _setup_ddp(self) -> None:
306
+ """Sets up Distributed Data Parallel training configuration.
307
+
308
+ Initializes the process group, sets up rank information, and configures the CUDA
309
+ device for the current process in DDP mode.
310
+ """
311
+ required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
312
+ for var in required_env_vars:
313
+ if var not in os.environ:
314
+ raise ValueError(f"DDP enabled but {var} environment variable not set")
315
+
316
+ if not torch.cuda.is_available():
317
+ raise RuntimeError("DDP requires CUDA but CUDA is not available")
318
+
319
+ if not torch.distributed.is_initialized():
320
+ init_process_group(backend="nccl")
321
+
322
+ self.ddp_rank = int(os.environ["RANK"])
323
+ self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
324
+ self.ddp_world_size = int(os.environ["WORLD_SIZE"])
325
+
326
+ self.device = torch.device(f"cuda:{self.ddp_local_rank}")
327
+ torch.cuda.set_device(self.device)
328
+
329
+ self.master_process = self.ddp_rank == 0
330
+
331
+ if self.master_process:
332
+ print(f"DDP initialized with world_size={self.ddp_world_size}")
333
+
334
+ def _setup_single_gpu(self) -> None:
335
+ """Sets up single GPU or CPU training configuration.
336
+
337
+ Configures the training setup for single-device operation, setting rank and process
338
+ information for non-DDP training.
339
+ """
340
+ self.ddp_rank = 0
341
+ self.ddp_local_rank = 0
342
+ self.ddp_world_size = 1
343
+ self.master_process = True
344
+
345
+ @staticmethod
346
+ def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
347
+ """Creates a learning rate scheduler for warmup.
348
+
349
+ Generates a scheduler that linearly increases the learning rate from 0 to the
350
+ optimizer's initial value over the specified warmup epochs.
351
+
352
+ Parameters
353
+ ----------
354
+ `optimizer` : torch.optim.Optimizer
355
+ Optimizer to apply the scheduler to.
356
+ `warmup_epochs` : int
357
+ Number of epochs for the warmup phase.
358
+
359
+ Returns
360
+ -------
361
+ lr_scheduler : torch.optim.lr_scheduler.LambdaLR
362
+ Learning rate scheduler for warmup.
363
+ """
364
+ def lr_lambda(epoch):
365
+ return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
366
+
367
+ return LambdaLR(optimizer, lr_lambda)
368
+
369
+ def _wrap_models_for_ddp(self) -> None:
370
+ """Wraps models with DistributedDataParallel for multi-GPU training.
371
+
372
+ Configures the decoder model and, if fine-tuning, the projection models for DDP training.
373
+ """
374
+ if self.use_ddp:
375
+ self.decoder_model = self.decoder_model.to(self.ddp_local_rank)
376
+ self.decoder_model = DDP(
377
+ self.decoder_model,
378
+ device_ids=[self.ddp_local_rank],
379
+ find_unused_parameters=True
380
+ )
381
+ # only wrap text_projection and image_projection if they are trainable
382
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
383
+ self.clip_text_projection = self.clip_text_projection.to(self.ddp_local_rank)
384
+ self.clip_image_projection = self.clip_image_projection.to(self.ddp_local_rank)
385
+ self.clip_text_projection = DDP(self.clip_text_projection, device_ids=[self.ddp_local_rank])
386
+ self.clip_image_projection = DDP(self.clip_image_projection, device_ids=[self.ddp_local_rank])
387
+
388
+ def _compile_models(self) -> None:
389
+ """Compiles models for optimization if supported.
390
+
391
+ Attempts to compile the decoder model and, if fine-tuning, the projection models using
392
+ torch.compile for optimization, falling back to uncompiled execution if compilation fails.
393
+ """
394
+ if self.use_compilation:
395
+ try:
396
+ self.decoder_model = self.decoder_model.to(self.device)
397
+ self.decoder_model = torch.compile(self.decoder_model, mode="reduce-overhead")
398
+ # only compile text_projection and image_projection if they are trainable
399
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None and self.finetune_clip_projections:
400
+ self.clip_text_projection = self.clip_text_projection.to(self.device)
401
+ self.clip_image_projection = self.clip_image_projection.to(self.device)
402
+ self.clip_text_projection = torch.compile(self.clip_text_projection, mode="reduce-overhead")
403
+ self.clip_image_projection = torch.compile(self.clip_image_projection, mode="reduce-overhead")
404
+ if self.master_process:
405
+ print("Models compiled successfully")
406
+ except Exception as e:
407
+ if self.master_process:
408
+ print(f"Model compilation failed: {e}. Continuing without compilation.")
409
+
410
+ def _get_clip_embeddings(
411
+ self,
412
+ images: torch.Tensor,
413
+ texts: Union[List, torch.Tensor]
414
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
415
+ """Encodes images and texts using the CLIP model.
416
+
417
+ Generates text and image embeddings using the CLIP model, with optional normalization.
418
+
419
+ Parameters
420
+ ----------
421
+ `images` : torch.Tensor
422
+ Input images, shape (batch_size, channels, height, width).
423
+ `texts` : Union[List, torch.Tensor]
424
+ Text prompts for conditional generation.
425
+
426
+ Returns
427
+ -------
428
+ text_embeddings : torch.Tensor
429
+ CLIP text embeddings, shape (batch_size, embedding_dim).
430
+ image_embeddings : torch.Tensor
431
+ CLIP image embeddings, shape (batch_size, embedding_dim).
432
+ """
433
+ with torch.no_grad():
434
+ # encode text y with CLIP text encoder: z_t ← CLIP_text(y)
435
+ text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize)
436
+ # encode image x with CLIP image encoder: z_i ← CLIP_image(x)
437
+ image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize)
438
+ return text_embeddings, image_embeddings
439
+
440
+ def _apply_dimensionality_reduction(
441
+ self,
442
+ text_embeddings: torch.Tensor,
443
+ image_embeddings: torch.Tensor
444
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
445
+ """Applies dimensionality reduction to embeddings if enabled.
446
+
447
+ Projects text and image embeddings to a lower-dimensional space using learned
448
+ projection layers, mimicking PCA as used in the UnCLIP paper.
449
+
450
+ Parameters
451
+ ----------
452
+ `text_embeddings` : torch.Tensor
453
+ CLIP text embeddings, shape (batch_size, embedding_dim).
454
+ `image_embeddings` : torch.Tensor
455
+ CLIP image embeddings, shape (batch_size, embedding_dim).
456
+
457
+ Returns
458
+ -------
459
+ text_embeddings : torch.Tensor
460
+ Projected text embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
461
+ image_embeddings : torch.Tensor
462
+ Projected image embeddings, shape (batch_size, output_dim) if reduced, else unchanged.
463
+ """
464
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
465
+ if not self.finetune_clip_projections:
466
+ with torch.no_grad():
467
+ text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
468
+ image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
469
+ else:
470
+ text_embeddings = self.clip_text_projection(text_embeddings.to(self.device))
471
+ image_embeddings = self.clip_image_projection(image_embeddings.to(self.device))
472
+ return text_embeddings.to(self.device), image_embeddings.to(self.device)
473
+
474
+ def _compute_mean_loss(self, losses: List[float]) -> float:
475
+ """Computes mean loss with DDP synchronization if needed.
476
+
477
+ Calculates the mean of the provided losses and synchronizes the result across
478
+ processes in DDP mode.
479
+
480
+ Parameters
481
+ ----------
482
+ `losses` : List[float]
483
+ List of loss values for the current epoch.
484
+
485
+ Returns
486
+ -------
487
+ mean_loss : float
488
+ Mean loss value, synchronized if using DDP.
489
+ """
490
+ if not losses:
491
+ return 0.0
492
+ mean_loss = sum(losses) / len(losses)
493
+ if self.use_ddp:
494
+ # synchronize loss across all processes
495
+ loss_tensor = torch.tensor(mean_loss, device=self.device)
496
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
497
+ mean_loss = (loss_tensor / self.ddp_world_size).item()
498
+
499
+ return mean_loss
500
+
501
+ def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
502
+ """Saves model checkpoint.
503
+
504
+ Saves the state of the decoder model, its submodules, optimizer, and schedulers,
505
+ with options for best model and epoch-specific checkpoints.
506
+
507
+ Parameters
508
+ ----------
509
+ `epoch` : int
510
+ Current epoch number.
511
+ `loss` : float
512
+ Current loss value.
513
+ `is_best` : bool, optional
514
+ Whether to save as the best model checkpoint (default: False).
515
+ `suffix` : str, optional
516
+ Suffix to add to checkpoint filename, default "".
517
+ """
518
+ if not self.master_process:
519
+ return
520
+ checkpoint = {
521
+ 'epoch': epoch,
522
+ 'loss': loss,
523
+ # Core models (submodules of decoder_model)
524
+ 'noise_predictor_state_dict': self.decoder_model.module.noise_predictor.state_dict() if self.use_ddp else self.decoder_model.noise_predictor.state_dict(),
525
+ 'optimizer_state_dict': self.optimizer.state_dict(),
526
+ # Training configuration
527
+ 'embedding_dim': self.clip_embedding_dim,
528
+ 'output_dim': self.transformer_embedding_dim,
529
+ 'reduce_dim': self.reduce_clip_embedding_dim,
530
+ 'normalize': self.normalize
531
+ }
532
+
533
+ # Save conditional model (submodule of decoder_model)
534
+ if self.decoder_model.conditional_model is not None:
535
+ checkpoint['conditional_model_state_dict'] = (
536
+ self.decoder_model.module.conditional_model.state_dict() if self.use_ddp
537
+ else self.decoder_model.conditional_model.state_dict()
538
+ )
539
+
540
+ # Save variance scheduler (submodule of decoder_model, always saved)
541
+ checkpoint['variance_scheduler_state_dict'] = (
542
+ self.decoder_model.forward_diffusion.module.variance_scheduler.state_dict() if self.use_ddp
543
+ else self.decoder_model.forward_diffusion.variance_scheduler.state_dict()
544
+ )
545
+
546
+ # Save CLIP time projection layer (submodule of decoder_model)
547
+ checkpoint['clip_time_proj_state_dict'] = (
548
+ self.decoder_model.module.clip_time_proj.state_dict() if self.use_ddp
549
+ else self.decoder_model.clip_time_proj.state_dict()
550
+ )
551
+
552
+ # Save decoder projection layer (submodule of decoder_model)
553
+ checkpoint['decoder_projection_state_dict'] = (
554
+ self.decoder_model.module.decoder_projection.state_dict() if self.use_ddp
555
+ else self.decoder_model.decoder_projection.state_dict()
556
+ )
557
+
558
+ # Save projection models (PCA equivalent)
559
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
560
+ checkpoint['text_projection_state_dict'] = (
561
+ self.clip_text_projection.module.state_dict() if self.use_ddp
562
+ else self.clip_text_projection.state_dict()
563
+ )
564
+ checkpoint['image_projection_state_dict'] = (
565
+ self.clip_image_projection.module.state_dict() if self.use_ddp
566
+ else self.clip_image_projection.state_dict()
567
+ )
568
+
569
+ # Save schedulers state
570
+ checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
571
+ checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()
572
+
573
+ filename = f"unclip_decoder_epoch_{epoch}{suffix}.pth"
574
+ if is_best:
575
+ filename = f"unclip_decoder_best{suffix}.pth"
576
+
577
+ filepath = os.path.join(self.store_path, filename)
578
+ os.makedirs(self.store_path, exist_ok=True)
579
+ torch.save(checkpoint, filepath)
580
+
581
+ if is_best:
582
+ print(f"Best model saved: {filepath}")
583
+
584
    def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
        """Loads model checkpoint.

        Restores the state of the decoder model's submodules, the optional CLIP
        projection layers, the optimizer, and both LR schedulers from a saved
        checkpoint. Each component loads independently: a failure emits a
        warning and loading continues with the remaining components.

        NOTE(review): when `use_ddp` is True, `self.decoder_model` is a DDP
        wrapper, so direct accesses like `self.decoder_model.noise_predictor`
        below may fail — confirm this method is called before DDP wrapping,
        or unwrap via `.module`.

        Parameters
        ----------
        `checkpoint_path` : str
            Path to the checkpoint file.

        Returns
        -------
        epoch : int
            The epoch at which the checkpoint was saved.
        loss : float
            The loss at the checkpoint.

        Raises
        ------
        FileNotFoundError
            If `checkpoint_path` does not exist.
        """
        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
            """Helper function to load state dict with DDP compatibility."""
            try:
                # Handle DDP state dict compatibility: add or strip the 'module.'
                # prefix so a checkpoint saved in one mode loads in the other
                if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
                    state_dict = {f'module.{k}': v for k, v in state_dict.items()}
                elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
                    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

                model.load_state_dict(state_dict)
                if self.master_process:
                    print(f"✓ Loaded {model_name}")
            except Exception as e:
                # non-fatal: warn and let the caller keep the current weights
                warnings.warn(f"Failed to load {model_name}: {e}")

        # Load core noise predictor model (submodule of decoder_model)
        if 'noise_predictor_state_dict' in checkpoint:
            _load_model_state_dict(self.decoder_model.noise_predictor, checkpoint['noise_predictor_state_dict'],
                                   'noise_predictor')

        # Load conditional model (submodule of decoder_model)
        if self.decoder_model.conditional_model is not None and 'conditional_model_state_dict' in checkpoint:
            _load_model_state_dict(self.decoder_model.conditional_model, checkpoint['conditional_model_state_dict'],
                                   'conditional_model')

        # Load variance scheduler (submodule of decoder_model)
        if 'variance_scheduler_state_dict' in checkpoint:
            state_dict = checkpoint.get('variance_scheduler_state_dict')
            try:
                _load_model_state_dict(self.decoder_model.forward_diffusion.variance_scheduler, state_dict, 'variance_scheduler')
            except Exception as e:
                warnings.warn(f"Failed to load variance scheduler: {e}")

        # Load CLIP time projection layer (submodule of decoder_model)
        if 'clip_time_proj_state_dict' in checkpoint:
            try:
                _load_model_state_dict(self.decoder_model.clip_time_proj, checkpoint['clip_time_proj_state_dict'],
                                       'clip_time_proj')
            except Exception as e:
                warnings.warn(f"Failed to load CLIP time projection: {e}")

        # Load decoder projection layer (submodule of decoder_model)
        if 'decoder_projection_state_dict' in checkpoint:
            try:
                _load_model_state_dict(self.decoder_model.decoder_projection,
                                       checkpoint['decoder_projection_state_dict'], 'decoder_projection')
            except Exception as e:
                warnings.warn(f"Failed to load decoder projection: {e}")

        # Load projection models (PCA equivalent)
        if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
            if 'text_projection_state_dict' in checkpoint:
                _load_model_state_dict(self.clip_text_projection, checkpoint['text_projection_state_dict'],
                                       'text_projection')
            if 'image_projection_state_dict' in checkpoint:
                _load_model_state_dict(self.clip_image_projection, checkpoint['image_projection_state_dict'],
                                       'image_projection')

        # Load optimizer
        if 'optimizer_state_dict' in checkpoint:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if self.master_process:
                    print("✓ Loaded optimizer")
            except Exception as e:
                warnings.warn(f"Failed to load optimizer state: {e}")

        # Load schedulers
        if 'scheduler_state_dict' in checkpoint:
            try:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded main scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load scheduler state: {e}")

        if 'warmup_scheduler_state_dict' in checkpoint:
            try:
                self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded warmup scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load warmup scheduler state: {e}")

        # Verify configuration compatibility (warn-only; does not block loading)
        if 'embedding_dim' in checkpoint:
            if checkpoint['embedding_dim'] != self.clip_embedding_dim:
                warnings.warn(
                    f"Embedding dimension mismatch: checkpoint={checkpoint['embedding_dim']}, current={self.clip_embedding_dim}")

        if 'reduce_dim' in checkpoint:
            if checkpoint['reduce_dim'] != self.reduce_clip_embedding_dim:
                warnings.warn(
                    f"Reduce dimension setting mismatch: checkpoint={checkpoint['reduce_dim']}, current={self.reduce_clip_embedding_dim}")

        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint.get('loss', float('inf'))

        if self.master_process:
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
            print(f"Epoch: {epoch}, Loss: {loss:.4f}")

        return epoch, loss
710
+
711
+ def validate(self) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float], Optional[float]]:
712
+ """Validates the UnCLIP decoder model.
713
+
714
+ Computes validation loss and optional metrics (FID, MSE, PSNR, SSIM, LPIPS) by
715
+ encoding images and texts, applying forward diffusion, predicting noise, and
716
+ reconstructing images through reverse diffusion.
717
+
718
+ Returns
719
+ -------
720
+ val_loss : float
721
+ Mean validation loss.
722
+ fid_avg : float or None
723
+ Average FID score, if computed.
724
+ mse_avg : float or None
725
+ Average MSE score, if computed.
726
+ psnr_avg : float or None
727
+ Average PSNR score, if computed.
728
+ ssim_avg : float or None
729
+ Average SSIM score, if computed.
730
+ lpips_avg : float or None
731
+ Average LPIPS score, if computed.
732
+ """
733
+
734
+ # set models to eval mode for evaluation
735
+ self.decoder_model.eval() # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj, decoder_projection to eval mode
736
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
737
+ self.clip_text_projection.eval()
738
+ self.clip_image_projection.eval()
739
+ if self.clip_model is not None:
740
+ self.clip_model.eval()
741
+
742
+ val_losses = []
743
+ fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
744
+
745
+ with torch.no_grad():
746
+ for images, texts in self.val_loader:
747
+ images = images.to(self.device, non_blocking=True)
748
+ images_orig = images.clone()
749
+ text_embeddings, image_embeddings = self._get_clip_embeddings(images, texts)
750
+ text_embeddings, image_embeddings = self._apply_dimensionality_reduction(
751
+ text_embeddings, image_embeddings
752
+ )
753
+ p_classifier_free = torch.rand(1).item()
754
+ p_text_drop = torch.rand(1).item()
755
+ predicted_noise, noise = self.decoder_model(
756
+ image_embeddings,
757
+ text_embeddings,
758
+ images,
759
+ texts,
760
+ p_classifier_free,
761
+ p_text_drop
762
+ )
763
+ loss = self.objective(predicted_noise, noise)
764
+ val_losses.append(loss.item())
765
+
766
+ if self.metrics_ is not None and self.decoder_model.reverse_diffusion is not None:
767
+ xt = torch.randn_like(images).to(self.device)
768
+ for t in reversed(range(self.decoder_model.forward_diffusion.variance_scheduler.tau_num_steps)):
769
+ time_steps = torch.full((xt.shape[0],), t, device=self.device, dtype=torch.long)
770
+ prev_time_steps = torch.full((xt.shape[0],), max(t - 1, 0), device=self.device, dtype=torch.long)
771
+ image_embeddings = self.decoder_model._apply_classifier_free_guidance(image_embeddings, p_classifier_free)
772
+ text_embeddings = self.decoder_model._apply_text_dropout(text_embeddings, p_text_drop)
773
+ c = self.decoder_model.decoder_projection(image_embeddings) # updated to submodule
774
+ y_encoded = self.decoder_model._encode_text_with_glide(texts if text_embeddings is not None else None)
775
+ context = self.decoder_model._concatenate_embeddings(y_encoded, c)
776
+ clip_image_embedding = self.decoder_model.clip_time_proj(image_embeddings)
777
+ predicted_noise = self.decoder_model.noise_predictor(xt, time_steps, context, clip_image_embedding)
778
+ xt, _ = self.decoder_model.reverse_diffusion(xt, predicted_noise, time_steps, prev_time_steps)
779
+
780
+ x_hat = torch.clamp(xt, min=self.image_output_range[0], max=self.image_output_range[1])
781
+
782
+ if self.normalize:
783
+ x_hat = (x_hat - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
784
+ x_orig = (images_orig - self.image_output_range[0]) / (self.image_output_range[1] - self.image_output_range[0])
785
+
786
+ metrics_result = self.metrics_.forward(x_orig, x_hat)
787
+ fid = metrics_result[0] if getattr(self.metrics_, 'fid', False) else float('inf')
788
+ mse = metrics_result[1] if getattr(self.metrics_, 'metrics', False) else None
789
+ psnr = metrics_result[2] if getattr(self.metrics_, 'metrics', False) else None
790
+ ssim = metrics_result[3] if getattr(self.metrics_, 'metrics', False) else None
791
+ lpips_score = metrics_result[4] if getattr(self.metrics_, 'lpips', False) else None
792
+
793
+ if fid != float('inf'):
794
+ fid_scores.append(fid)
795
+ if mse is not None:
796
+ mse_scores.append(mse)
797
+ if psnr is not None:
798
+ psnr_scores.append(psnr)
799
+ if ssim is not None:
800
+ ssim_scores.append(ssim)
801
+ if lpips_score is not None:
802
+ lpips_scores.append(lpips_score)
803
+
804
+ # compute averages
805
+ val_loss = torch.tensor(val_losses).mean().item()
806
+ fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
807
+ mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
808
+ psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
809
+ ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
810
+ lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
811
+
812
+ # synchronize metrics across GPUs in DDP mode
813
+ if self.use_ddp:
814
+ metrics = [val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg]
815
+ metrics_tensors = [torch.tensor(m, device=self.device) if m is not None else torch.tensor(float('inf'), device=self.device) for m in metrics]
816
+ for tensor in metrics_tensors:
817
+ dist.all_reduce(tensor, op=dist.ReduceOp.AVG)
818
+ val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg = [t.item() if t.item() != float('inf') else (None if i > 1 else float('inf')) for i, t in enumerate(metrics_tensors)]
819
+
820
+ # return to training mode
821
+ self.decoder_model.train() # sets noise_predictor, conditional_model, variance_scheduler, clip_time_proj, decoder_projection to train mode
822
+ if not self.decoder_model.variance_scheduler.trainable_beta:
823
+ self.decoder_model.variance_scheduler.eval()
824
+ if self.reduce_clip_embedding_dim and self.clip_text_projection is not None and self.clip_image_projection is not None:
825
+ if self.finetune_clip_projections:
826
+ self.clip_text_projection.train()
827
+ self.clip_image_projection.train()
828
+ else:
829
+ self.clip_text_projection.eval()
830
+ self.clip_image_projection.eval()
831
+ if self.clip_model is not None:
832
+ self.clip_model.eval()
833
+
834
+ return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
835
+
836
+
837
+ """
838
+ from utils import NoisePredictor, TextEncoder, Metrics
839
+ from clip_model import CLIPEncoder
840
+ from torchvision import datasets, transforms
841
+ from torch.utils.data import DataLoader, Subset, Dataset
842
+ from project_prior import Projection
843
+ import torch
844
+ from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP, ReverseUnCLIP
845
+ from decoder_model import UnClipDecoder
846
+
847
+
848
+ class CIFAR10WithCaptions(Dataset):
849
+ def __init__(self, cifar_dataset):
850
+ self.dataset = cifar_dataset
851
+ self.class_names = [
852
+ 'airplane', 'automobile', 'bird', 'cat', 'deer',
853
+ 'dog', 'frog', 'horse', 'ship', 'truck'
854
+ ]
855
+ # More descriptive templates
856
+ self.templates = [
857
+ "A photo of a {}",
858
+ "An image of a {}",
859
+ "A picture of a {}",
860
+ "This is a {}",
861
+ ]
862
+
863
+ def __len__(self):
864
+ return len(self.dataset)
865
+
866
+ def __getitem__(self, idx):
867
+ image, label = self.dataset[idx]
868
+ class_name = self.class_names[label]
869
+ # Use different templates for variety
870
+ template = self.templates[idx % len(self.templates)]
871
+ caption = template.format(class_name)
872
+ return image, caption
873
+
874
+
875
+
876
+ # Updated transforms for CLIP
877
+ transform = transforms.Compose([
878
+ transforms.Resize((224, 224)),
879
+ transforms.ToTensor(),
880
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
881
+ ])
882
+
883
+ # Load CIFAR-10 with captions
884
+ cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
885
+ cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
886
+
887
+ train_dataset = CIFAR10WithCaptions(cifar_train)
888
+ test_dataset = CIFAR10WithCaptions(cifar_test)
889
+
890
+ # Small subset for testing
891
+ train_subset_indices = torch.randperm(len(train_dataset))[:4]
892
+ test_subset_indices = torch.randperm(len(test_dataset))[:2]
893
+ train_subset = Subset(train_dataset, train_subset_indices)
894
+ test_subset = Subset(test_dataset, test_subset_indices)
895
+
896
+ # DataLoaders
897
+ t_loader = DataLoader(train_subset, batch_size=2, shuffle=True, pin_memory=True)
898
+ v_loader = DataLoader(test_subset, batch_size=1, shuffle=False, pin_memory=True)
899
+
900
+ d = torch.device("cuda")
901
+
902
+ n_model = NoisePredictor(
903
+ in_channels=3,
904
+ down_channels=[16, 32],
905
+ mid_channels=[32, 32],
906
+ up_channels=[32, 16],
907
+ down_sampling=[True, True],
908
+ time_embed_dim=32,
909
+ y_embed_dim=32,
910
+ num_down_blocks=2,
911
+ num_mid_blocks=2,
912
+ num_up_blocks=2,
913
+ down_sampling_factor=2
914
+ )
915
+
916
+
917
+ c_model = CLIPEncoder(
918
+ model_name="openai/clip-vit-base-patch32",
919
+ device="cuda",
920
+ use_fast=False
921
+ )
922
+
923
+
924
+ t_proj = Projection(
925
+ input_dim=512,
926
+ output_dim=32,
927
+ hidden_dim=128,
928
+ num_layers=2,
929
+ dropout=0.1,
930
+ use_layer_norm=True
931
+ )
932
+ i_proj = Projection(
933
+ input_dim=512,
934
+ output_dim=32,
935
+ hidden_dim=128,
936
+ num_layers=2,
937
+ dropout=0.1,
938
+ use_layer_norm=True
939
+ )
940
+
941
+ h_model = VarianceSchedulerUnCLIP(
942
+ num_steps=500,
943
+ beta_start=1e-4,
944
+ beta_end=0.02,
945
+ trainable_beta=False,
946
+ beta_method="linear"
947
+ )
948
+ for_ = ForwardUnCLIP(h_model)
949
+ rev_ = ReverseUnCLIP(h_model)
950
+
951
+ cond = TextEncoder(
952
+ use_pretrained_model=True,
953
+ model_name="bert-base-uncased",
954
+ vocabulary_size=30522,
955
+ num_layers=2,
956
+ input_dimension=32,
957
+ output_dimension=32,
958
+ num_heads=2,
959
+ context_length=77
960
+ ).to(d)
961
+
962
+ decoder = UnClipDecoder(
963
+ embedding_dim=32,
964
+ noise_predictor=n_model,
965
+ forward_diffusion=for_,
966
+ reverse_diffusion=rev_,
967
+ conditional_model=cond, # GLIDE text encoder
968
+ tokenizer=None,
969
+ device="cpu",
970
+ output_range=(-1.0, 1.0),
971
+ normalize=True,
972
+ classifier_free=0.1, # paper specifies 10%
973
+ drop_caption=0.5, # paper specifies 50%
974
+ max_length=77
975
+ )
976
+
977
+ opt = torch.optim.AdamW([p for p in decoder.parameters() if p.requires_grad], lr=1e-3)
978
+
979
+
980
+ obj = nn.MSELoss()
981
+
982
+ mets = Metrics(
983
+ device="cpu",
984
+ fid=True,
985
+ metrics=True,
986
+ lpips_=True
987
+ )
988
+
989
+
990
+ model = TrainUnClipDecoder(
991
+ embedding_dim=512,
992
+ decoder_model=decoder,
993
+ clip_model=c_model,
994
+ train_loader=t_loader,
995
+ optimizer=opt,
996
+ objective=obj,
997
+ text_projection=t_proj,
998
+ image_projection=i_proj,
999
+ val_loader=v_loader,
1000
+ metrics_=mets,
1001
+ max_epoch=5,
1002
+ device="cuda",
1003
+ store_path="unclip_decoder",
1004
+ patience=5,
1005
+ warmup_epochs=2,
1006
+ val_frequency=10,
1007
+ use_ddp=False,
1008
+ num_grad_accumulation=1,
1009
+ progress_frequency=1,
1010
+ compilation=False,
1011
+ output_range=(-1.0, 1.0),
1012
+ reduce_dim=True,
1013
+ output_dim=32,
1014
+ normalize=True,
1015
+ finetune_projections=False
1016
+ )
1017
+
1018
+ # Ensure requires_grad is set correctly
1019
+ for p in model.clip_model.parameters():
1020
+ p.requires_grad = False
1021
+ if not model.finetune_projections:
1022
+ for p in model.text_projection.parameters():
1023
+ p.requires_grad = False
1024
+ for p in model.image_projection.parameters():
1025
+ p.requires_grad = False
1026
+ if not model.decoder_model.forward_diffusion.variance_scheduler.trainable_beta:
1027
+ for p in model.decoder_model.forward_diffusion.variance_scheduler.parameters():
1028
+ p.requires_grad = False
1029
+
1030
+ # Run training
1031
+ one, two = model()
1032
+
1033
+ # Count trainable parameters
1034
+ def count_trainable_parameters(model, finetune_projections=False):
1035
+ total_params = 0
1036
+ total_params += sum(p.numel() for p in model.decoder_model.parameters() if p.requires_grad)
1037
+ if finetune_projections and model.text_projection is not None and model.image_projection is not None:
1038
+ total_params += sum(p.numel() for p in model.text_projection.parameters() if p.requires_grad)
1039
+ total_params += sum(p.numel() for p in model.image_projection.parameters() if p.requires_grad)
1040
+ return total_params
1041
+
1042
+ # Case 1: finetune_projections=False, train_projection=False
1043
+ print("Trainable parameters (finetune_projections=False, train_projection=False):")
1044
+ total_params_false = count_trainable_parameters(model, finetune_projections=False)
1045
+ print(f"Total trainable parameters: {total_params_false}")
1046
+
1047
+ # Case 2: finetune_projections=True, train_projection=True
1048
+ model.finetune_projections = True
1049
+ for p in model.text_projection.parameters():
1050
+ p.requires_grad = True
1051
+ for p in model.image_projection.parameters():
1052
+ p.requires_grad = True
1053
+
1054
+ print("\nTrainable parameters (finetune_projections=True, train_projection=True):")
1055
+ total_params_true = count_trainable_parameters(model, finetune_projections=True)
1056
+ print(f"Total trainable parameters: {total_params_true}")
1057
+
1058
+ print("After parameters count")
1059
+ """