torchdiff-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_ldm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
unclip/train_prior.py ADDED
@@ -0,0 +1,757 @@
+ import torch
+ import torch.nn as nn
+ from typing import Optional, List, Tuple, Union, Callable
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.distributed import init_process_group, destroy_process_group
+ from tqdm import tqdm
+ import warnings
+ import os
+
+
+ class TrainUnCLIPPrior(nn.Module):
+     """Trainer for the UnCLIPTransformerPrior model.
+
+     Handles training of the UnCLIP prior model to predict clean image embeddings from
+     noisy image embeddings and text embeddings, with support for dimension reduction,
+     mixed-precision training, and distributed training.
+
+     Parameters
+     ----------
+     `prior_model` : nn.Module
+         The UnCLIP prior model to be trained (e.g., UnCLIPTransformerPrior).
+     `clip_model` : nn.Module
+         CLIP model for encoding text and images.
+     `train_loader` : torch.utils.data.DataLoader
+         DataLoader for training data.
+     `optimizer` : torch.optim.Optimizer
+         Optimizer for training the prior model.
+     `objective` : Callable
+         Loss function computing the difference between predicted and target embeddings.
+     `val_loader` : torch.utils.data.DataLoader, optional
+         DataLoader for validation data, default None.
+     `max_epochs` : int, optional
+         Maximum number of training epochs (default: 1000).
+     `device` : Union[str, torch.device], optional
+         Device for computation (default: CUDA if available, else CPU).
+     `store_path` : str, optional
+         Directory path to save model checkpoints, default None.
+     `patience` : int, optional
+         Number of epochs to wait for improvement before early stopping (default: 100).
+     `warmup_epochs` : int, optional
+         Number of epochs for learning rate warmup (default: 100).
+     `val_frequency` : int, optional
+         Frequency (in epochs) for validation (default: 10).
+     `use_ddp` : bool, optional
+         Whether to use Distributed Data Parallel training (default: False).
+     `grad_accumulation_steps` : int, optional
+         Number of gradient accumulation steps before each optimizer update (default: 1).
+     `log_frequency` : int, optional
+         Frequency (in epochs) for printing training progress (default: 1).
+     `use_compilation` : bool, optional
+         Whether to compile models with torch.compile (default: False).
+     `embedding_output_range` : Tuple[float, float], optional
+         Range for clamping output embeddings (default: (-1.0, 1.0)).
+     `reduce_clip_embedding_dim` : bool, optional
+         Whether to apply dimension reduction to embeddings (default: True).
+     `transformer_embedding_dim` : int, optional
+         Target dimensionality for reduced embeddings (default: 319).
+     `normalize_clip_embeddings` : bool, optional
+         Whether to normalize CLIP embeddings (default: True).
+     """
+
+     def __init__(
+         self,
+         prior_model: nn.Module,
+         clip_model: nn.Module,
+         train_loader: torch.utils.data.DataLoader,
+         optimizer: torch.optim.Optimizer,
+         objective: Callable,
+         val_loader: Optional[torch.utils.data.DataLoader] = None,
+         max_epochs: int = 1000,
+         device: Optional[Union[str, torch.device]] = None,
+         store_path: Optional[str] = None,
+         patience: int = 100,
+         warmup_epochs: int = 100,
+         val_frequency: int = 10,
+         use_ddp: bool = False,
+         grad_accumulation_steps: int = 1,
+         log_frequency: int = 1,
+         use_compilation: bool = False,
+         embedding_output_range: Tuple[float, float] = (-1.0, 1.0),
+         reduce_clip_embedding_dim: bool = True,
+         transformer_embedding_dim: int = 319,
+         normalize_clip_embeddings: bool = True
+     ) -> None:
+         super().__init__()
+
+         # Training configuration
+         self.use_ddp = use_ddp
+         self.grad_accumulation_steps = grad_accumulation_steps
+         if device is None:
+             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         elif isinstance(device, str):
+             self.device = torch.device(device)
+         else:
+             self.device = device
+
+         # Setup distributed training
+         if self.use_ddp:
+             self._setup_ddp()
+         else:
+             self._setup_single_gpu()
+
+         # Core models
+         self.prior_model = prior_model.to(self.device)
+         self.clip_model = clip_model.to(self.device)
+
+         # Training components
+         self.optimizer = optimizer
+         self.objective = objective
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+
+         # Training parameters
+         self.max_epochs = max_epochs
+         self.patience = patience
+         self.val_frequency = val_frequency
+         self.log_frequency = log_frequency
+         self.use_compilation = use_compilation
+         self.embedding_output_range = embedding_output_range
+         self.reduce_clip_embedding_dim = reduce_clip_embedding_dim
+         self.normalize_clip_embeddings = normalize_clip_embeddings
+         self.transformer_embedding_dim = transformer_embedding_dim
+
+         # Checkpoint management (directory is created lazily in _save_checkpoint)
+         self.store_path = store_path
+
+         # Learning rate scheduling
+         self.scheduler = ReduceLROnPlateau(
+             self.optimizer,
+             patience=self.patience,
+             factor=0.5
+         )
+         self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
+
+     def _setup_ddp(self) -> None:
+         """Sets up Distributed Data Parallel training configuration.
+
+         Initializes the process group, sets up rank information, and configures the CUDA
+         device for the current process.
+
+         Raises
+         ------
+         ValueError
+             If required DDP environment variables (RANK, LOCAL_RANK, WORLD_SIZE) are not set.
+         RuntimeError
+             If CUDA is not available when DDP is enabled.
+         """
+         required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
+         for var in required_env_vars:
+             if var not in os.environ:
+                 raise ValueError(f"DDP enabled but {var} environment variable not set")
+
+         # Ensure CUDA is available for DDP
+         if not torch.cuda.is_available():
+             raise RuntimeError("DDP requires CUDA but CUDA is not available")
+
+         # Initialize process group only if not already initialized
+         if not torch.distributed.is_initialized():
+             init_process_group(backend="nccl")
+
+         # Get rank information
+         self.ddp_rank = int(os.environ["RANK"])  # Global rank across all nodes
+         self.ddp_local_rank = int(os.environ["LOCAL_RANK"])  # Local rank on current node
+         self.ddp_world_size = int(os.environ["WORLD_SIZE"])  # Total number of processes
+
+         # Set device and make it current
+         self.device = torch.device(f"cuda:{self.ddp_local_rank}")
+         torch.cuda.set_device(self.device)
+
+         # Master process handles logging, checkpointing, etc.
+         self.master_process = self.ddp_rank == 0
+
+         if self.master_process:
+             print(f"DDP initialized with world_size={self.ddp_world_size}")
+
+     def _setup_single_gpu(self) -> None:
+         """Sets up single GPU or CPU training configuration.
+
+         Configures the training setup for single-device operation, setting rank and process
+         information for non-DDP training.
+         """
+         self.ddp_rank = 0
+         self.ddp_local_rank = 0
+         self.ddp_world_size = 1
+         self.master_process = True
+
+     @staticmethod
+     def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
+         """Creates a learning rate scheduler for warmup.
+
+         Generates a scheduler that linearly increases the learning rate from 0 to the
+         optimizer's initial value over the specified warmup epochs.
+
+         Parameters
+         ----------
+         `optimizer` : torch.optim.Optimizer
+             Optimizer to apply the scheduler to.
+         `warmup_epochs` : int
+             Number of epochs for the warmup phase.
+
+         Returns
+         -------
+         lr_scheduler : torch.optim.lr_scheduler.LambdaLR
+             Learning rate scheduler for warmup.
+         """
+         def lr_lambda(epoch):
+             return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0
+         return LambdaLR(optimizer, lr_lambda)
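+
+     # e.g., with warmup_epochs=100 the LR multiplier is min(1, n/100) at the
+     # scheduler's n-th step: 0.0 at step 0, 0.5 at step 50, 1.0 from step 100 on.
+     # Note that forward() steps this scheduler once per optimizer update, not
+     # once per epoch.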
+
+     def _wrap_models_for_ddp(self) -> None:
+         """Wraps the prior model with DistributedDataParallel for multi-GPU training.
+
+         Configures the prior model for DDP, setting device IDs and handling unused parameters.
+         """
+         if self.use_ddp:
+             # Wrap prior with DDP
+             self.prior_model = DDP(
+                 self.prior_model,
+                 device_ids=[self.ddp_local_rank],
+                 find_unused_parameters=True
+             )
+
+     def _compile_models(self) -> None:
+         """Compiles models for optimization if supported.
+
+         Attempts to compile the prior model using torch.compile for performance optimization,
+         falling back to the uncompiled model if compilation fails.
+         """
+         if self.use_compilation:
+             try:
+                 self.prior_model = torch.compile(self.prior_model)
+                 if self.master_process:
+                     print("Models compiled successfully")
+             except Exception as e:
+                 if self.master_process:
+                     print(f"Model compilation failed: {e}. Continuing without compilation.")
+
+     def forward(self) -> Tuple[List[float], float]:
+         """Trains the UnCLIP prior model.
+
+         Executes the training loop, optimizing the prior model to predict clean image embeddings
+         from noisy embeddings and text conditions, with support for validation, early stopping,
+         and checkpointing.
+
+         Returns
+         -------
+         train_losses : List[float]
+             List of mean training losses per epoch.
+         best_val_loss : float
+             Best validation or training loss achieved.
+         """
+         # Set models to training mode
+         self.prior_model.train()
+
+         # Compile and wrap models
+         self._compile_models()
+         self._wrap_models_for_ddp()
+
+         # Initialize training components (scaling disabled on CPU)
+         scaler = torch.amp.GradScaler(enabled=self.device.type == "cuda")
+         train_losses = []
+         best_val_loss = float("inf")
+         wait = 0
+
+         # Main training loop
+         for epoch in range(self.max_epochs):
+             # Set epoch for distributed sampler if using DDP
+             if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
+                 self.train_loader.sampler.set_epoch(epoch)
+
+             train_losses_epoch = []
+
+             # Training step loop with gradient accumulation (x: images, y: captions)
+             for step, (x, y) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
+                 x = x.to(self.device, non_blocking=True)
+
+                 # Forward pass with mixed precision
+                 with torch.autocast(device_type=self.device.type):
+                     loss = self._compute_training_loss(x, y)
+                     loss = loss / self.grad_accumulation_steps
+
+                 # Backward pass (once per micro-batch)
+                 scaler.scale(loss).backward()
+
+                 # Optimizer step with gradient accumulation
+                 if (step + 1) % self.grad_accumulation_steps == 0:
+                     self._optimizer_step(scaler)
+                     # Update learning rate (warmup scheduler)
+                     self.warmup_lr_scheduler.step()
+
+                 # Record loss (unscaled)
+                 train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)
+
+             # Compute and sync training loss
+             mean_train_loss = self._compute_mean_loss(train_losses_epoch)
+             train_losses.append(mean_train_loss)
+
+             # Print training progress (only master process)
+             if self.master_process and (epoch + 1) % self.log_frequency == 0:
+                 current_lr = self.optimizer.param_groups[0]['lr']
+                 print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}", end="")
+
+             # Validation and checkpointing
+             current_loss = mean_train_loss
+             if self.val_loader is not None and (epoch + 1) % self.val_frequency == 0:
+                 val_loss = self.validate()
+                 current_loss = val_loss
+
+                 if self.master_process:
+                     print(f" | Val Loss: {val_loss:.4f}")
+             elif self.master_process:
+                 print()
+
+             # Learning rate scheduling
+             self.scheduler.step(current_loss)
+
+             # Save checkpoint and early stopping
+             if self.master_process:
+                 if current_loss < best_val_loss and (epoch + 1) % self.val_frequency == 0:
+                     best_val_loss = current_loss
+                     wait = 0
+                     self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
+                 else:
+                     wait += 1
+                     if wait >= self.patience:
+                         print("Early stopping triggered")
+                         self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
+                         break
+
+         # Cleanup
+         if self.use_ddp:
+             destroy_process_group()
+
+         return train_losses, best_val_loss
+
+     def _compute_training_loss(self, images: torch.Tensor, texts: List[str]) -> torch.Tensor:
+         """Computes the training loss for the UnCLIP prior model.
+
+         Calculates the loss by encoding images and text with CLIP, applying forward diffusion,
+         predicting clean embeddings, and comparing with target embeddings.
+
+         Parameters
+         ----------
+         `images` : torch.Tensor
+             Input images, shape (batch_size, channels, height, width).
+         `texts` : List[str]
+             List of text prompts for conditioning.
+
+         Returns
+         -------
+         loss : torch.Tensor
+             Loss value computed between predicted and target embeddings.
+         """
+         # Access submodules on the underlying module when wrapped in DDP
+         prior = self.prior_model.module if self.use_ddp else self.prior_model
+
+         with torch.no_grad():
+             # Encode text and image with CLIP
+             text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
+             image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
+
+         # Reduce dimensionality (optional); the projections are trainable,
+         # so they run outside the no_grad block
+         if self.reduce_clip_embedding_dim:
+             text_embeddings = prior.text_projection(text_embeddings)
+             image_embeddings = prior.image_projection(image_embeddings)
+
+         # Sample timestep t ~ Uniform({0, ..., T - 1})
+         batch_size = image_embeddings.shape[0]
+         timesteps = torch.randint(0, prior.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
+
+         # Sample noise ε ~ N(0, I)
+         noise = torch.randn_like(image_embeddings)
+
+         # Compute noised embedding z_{i,t}
+         noisy_image_embeddings = prior.forward_diffusion(image_embeddings, noise, timesteps)
+
+         # Predict unnoised embedding ẑ_i
+         predicted_image_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)
+
+         # Transform back to original space if using dimension reduction
+         if self.reduce_clip_embedding_dim:
+             predicted_image_embeddings = prior.image_projection.inverse_transform(predicted_image_embeddings)
+             target_embeddings = prior.image_projection.inverse_transform(image_embeddings)
+         else:
+             target_embeddings = image_embeddings
+
+         # Compute loss L = ||ẑ_i - z_i||²
+         loss = self.objective(predicted_image_embeddings, target_embeddings)
+         return loss
+
+     def _optimizer_step(self, scaler: torch.amp.GradScaler) -> None:
+         """Performs an optimizer step with gradient clipping.
+
+         Applies gradient clipping, updates the optimizer with scaled gradients, and resets
+         gradients for the next iteration.
+
+         Parameters
+         ----------
+         `scaler` : torch.amp.GradScaler
+             Gradient scaler for mixed precision training.
+         """
+         scaler.unscale_(self.optimizer)
+
+         # Gradient clipping
+         torch.nn.utils.clip_grad_norm_(self.prior_model.parameters(), max_norm=1.0)
+
+         scaler.step(self.optimizer)
+         scaler.update()
+         self.optimizer.zero_grad()
+
+     def _compute_mean_loss(self, losses: List[float]) -> float:
+         """Computes the mean loss and synchronizes across processes if using DDP.
+
+         Calculates the mean of the provided loss values and performs an all-reduce operation
+         in DDP mode to synchronize the loss across processes.
+
+         Parameters
+         ----------
+         `losses` : List[float]
+             List of loss values from a training or validation epoch.
+
+         Returns
+         -------
+         mean_loss : float
+             Mean loss value, synchronized across processes if DDP is enabled.
+         """
+         mean_loss = torch.tensor(losses).mean().item()
+
+         if self.use_ddp:
+             loss_tensor = torch.tensor(mean_loss, device=self.device)
+             dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
+             mean_loss = loss_tensor.item()
+
+         return mean_loss
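+
+     # Note: dist.ReduceOp.AVG is supported by the NCCL backend used here; on
+     # backends without AVG (e.g., gloo), all-reduce a SUM and divide by world_size.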
+
+     def validate(self) -> float:
+         """Validates the UnCLIP prior model.
+
+         Computes the validation loss by encoding images and text, applying forward diffusion,
+         predicting clean embeddings, and comparing with target embeddings.
+
+         Returns
+         -------
+         val_loss : float
+             Mean validation loss, synchronized across processes if DDP is enabled.
+         """
+         self.prior_model.eval()
+         # Access submodules on the underlying module when wrapped in DDP
+         prior = self.prior_model.module if self.use_ddp else self.prior_model
+
+         val_losses = []
+
+         with torch.no_grad():
+             for images, texts in self.val_loader:
+                 images = images.to(self.device, non_blocking=True)
+
+                 # Get embeddings
+                 text_embeddings = self.clip_model(data=texts, data_type="text", normalize=self.normalize_clip_embeddings)
+                 image_embeddings = self.clip_model(data=images, data_type="img", normalize=self.normalize_clip_embeddings)
+                 original_image_embeddings = image_embeddings.clone()
+
+                 if self.reduce_clip_embedding_dim:
+                     text_embeddings = prior.text_projection(text_embeddings)
+                     image_embeddings = prior.image_projection(image_embeddings)
+
+                 # Forward diffusion
+                 batch_size = image_embeddings.shape[0]
+                 timesteps = torch.randint(0, prior.forward_diffusion.variance_scheduler.num_steps, (batch_size,), device=self.device)
+                 noise = torch.randn_like(image_embeddings)
+                 noisy_image_embeddings = prior.forward_diffusion(image_embeddings, noise, timesteps)
+
+                 # Predict
+                 predicted_embeddings = self.prior_model(text_embeddings, noisy_image_embeddings, timesteps)
+
+                 if self.reduce_clip_embedding_dim:
+                     predicted_embeddings = prior.image_projection.inverse_transform(predicted_embeddings)
+
+                 # Compute loss
+                 loss = self.objective(predicted_embeddings, original_image_embeddings)
+                 val_losses.append(loss.item())
+
+         # Compute averages
+         val_loss = self._compute_mean_loss(val_losses)
+
+         # Return to training mode
+         self.prior_model.train()
+
+         return val_loss
+
+     def _save_checkpoint(self, epoch: int, loss: float, suffix: str = "", is_best: bool = False) -> None:
+         """Saves a model checkpoint.
+
+         Saves the state of the prior model and optimizer to a checkpoint file, with options
+         for best model or early stopping checkpoints.
+
+         Parameters
+         ----------
+         `epoch` : int
+             Current epoch number.
+         `loss` : float
+             Current loss value.
+         `suffix` : str, optional
+             Suffix to append to the checkpoint filename, default "".
+         `is_best` : bool, optional
+             Whether to save the checkpoint as the best model, default False.
+         """
+         if self.store_path is None:
+             warnings.warn("store_path is None; skipping checkpoint save")
+             return
+
+         try:
+             # Get state dicts
+             prior_state = (
+                 self.prior_model.module.state_dict() if self.use_ddp
+                 else self.prior_model.state_dict()
+             )
+
+             checkpoint = {
+                 'epoch': epoch,
+                 'prior_model_state_dict': prior_state,
+                 'optimizer_state_dict': self.optimizer.state_dict(),
+                 'loss': loss,
+                 'max_epochs': self.max_epochs,
+             }
+
+             # Create the directory if it doesn't exist
+             os.makedirs(self.store_path, exist_ok=True)
+
+             # Define the checkpoint filename
+             if is_best:
+                 filename = "best_model.pth"
+             else:
+                 filename = f"checkpoint_epoch_{epoch}{suffix}.pth"
+
+             # Construct the full save path
+             save_path = os.path.join(self.store_path, filename)
+
+             # Save checkpoint
+             torch.save(checkpoint, save_path)
+             if self.master_process:  # Only print from the master process in DDP
+                 print(f"Checkpoint saved: {save_path}")
+
+         except Exception as e:
+             print(f"Failed to save checkpoint: {e}")
+
+     def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
+         """Loads a model checkpoint to resume training.
+
+         Restores the prior model and optimizer states from a saved checkpoint, handling
+         DDP compatibility for state dictionaries.
+
+         Parameters
+         ----------
+         `checkpoint_path` : str
+             Path to the checkpoint file.
+
+         Returns
+         -------
+         epoch : int
+             The epoch at which the checkpoint was saved.
+         loss : float
+             The loss value at the checkpoint.
+         """
+         try:
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+         except FileNotFoundError:
+             raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+         # Load prior model
+         if 'prior_model_state_dict' in checkpoint:
+             state_dict = checkpoint['prior_model_state_dict']
+
+             # Handle DDP state dict compatibility
+             if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()):
+                 state_dict = {f'module.{k}': v for k, v in state_dict.items()}
+             elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
+                 state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+
+             self.prior_model.load_state_dict(state_dict)
+
+         # Load optimizer
+         if 'optimizer_state_dict' in checkpoint:
+             try:
+                 self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+             except Exception as e:
+                 warnings.warn(f"Failed to load optimizer state: {e}")
+
+         epoch = checkpoint.get('epoch', 0)
+         loss = checkpoint.get('loss', float('inf'))
+
+         if self.master_process:
+             print(f"Loaded checkpoint from {checkpoint_path} (epoch {epoch}, loss {loss:.4f})")
+
+         return epoch, loss
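+
+     # Illustrative resume (assuming a checkpoint was saved earlier with
+     # store_path="prior", as in the example below):
+     #   start_epoch, best_loss = trainer.load_checkpoint("prior/best_model.pth")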
+
+
+ """
+ Example usage:
+
+ from prior_diff import ForwardUnCLIP, ReverseUnCLIP, VarianceSchedulerUnCLIP
+ from prior_model import UnCLIPTransformerPrior
+ from clip_model import CLIPEncoder
+ from project_prior import Projection
+ from torchvision import datasets, transforms
+ from torch.utils.data import DataLoader, Subset, Dataset
+ import torch
+
+
+ # Use CIFAR-10 with descriptive captions
+ class CIFAR10WithCaptions(Dataset):
+     def __init__(self, cifar_dataset):
+         self.dataset = cifar_dataset
+         self.class_names = [
+             'airplane', 'automobile', 'bird', 'cat', 'deer',
+             'dog', 'frog', 'horse', 'ship', 'truck'
+         ]
+         # More descriptive templates
+         self.templates = [
+             "A photo of a {}",
+             "An image of a {}",
+             "A picture of a {}",
+             "This is a {}",
+         ]
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         image, label = self.dataset[idx]
+         class_name = self.class_names[label]
+         # Use different templates for variety
+         template = self.templates[idx % len(self.templates)]
+         caption = template.format(class_name)
+         return image, caption
+
+
+ # Transforms matching CLIP's expected input size
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ # Load CIFAR-10 with captions
+ cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
+ cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
+
+ train_dataset = CIFAR10WithCaptions(cifar_train)
+ test_dataset = CIFAR10WithCaptions(cifar_test)
+
+ # Small subset for testing
+ train_subset_indices = torch.randperm(len(train_dataset))[:100]
+ test_subset_indices = torch.randperm(len(test_dataset))[:20]
+
+ train_subset = Subset(train_dataset, train_subset_indices)
+ test_subset = Subset(test_dataset, test_subset_indices)
+
+ # DataLoaders
+ t_loader = DataLoader(train_subset, batch_size=32, shuffle=True, pin_memory=True)
+ val = DataLoader(test_subset, batch_size=10, shuffle=False, pin_memory=True)
+
+ h_model = VarianceSchedulerUnCLIP(
+     num_steps=1000,
+     beta_start=1e-4,
+     beta_end=0.02,
+     trainable_beta=True,
+     beta_method="cosine"
+ )
+
+ c_model = CLIPEncoder(model_name="openai/clip-vit-base-patch32")
+ tp = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=480,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ )
+ ip = Projection(
+     input_dim=512,
+     output_dim=320,
+     hidden_dim=480,
+     num_layers=2,
+     dropout=0.1,
+     use_layer_norm=True
+ )
+
+ d_model = ForwardUnCLIP(h_model)
+ r_model = ReverseUnCLIP(h_model)
+
+ p_model = UnCLIPTransformerPrior(
+     forward_diffusion=d_model,
+     reverse_diffusion=r_model,  # not called by this trainer; needed for sampling
+     text_projection=tp,  # learned projection, used instead of the PCA in the paper
+     image_projection=ip,
+     embedding_dim=320,
+     num_layers=12,
+     num_attention_heads=8,
+     feedforward_dim=512,
+     max_sequence_length=2,
+     dropout_rate=0.3
+ )
+
+ opt = torch.optim.AdamW([p for p in p_model.parameters() if p.requires_grad], lr=1e-3)
+
+ models = [h_model, p_model, tp, ip]
+
+ total_params = 0
+ for model in models:
+     total_params += sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Trainable parameters: {total_params}")
+
+ obj = nn.MSELoss()
+
+ train = TrainUnCLIPPrior(
+     prior_model=p_model,
+     clip_model=c_model,
+     train_loader=t_loader,
+     optimizer=opt,
+     objective=obj,
+     val_loader=val,
+     max_epochs=5,
+     device="cuda",
+     store_path="prior",
+     patience=3,
+     warmup_epochs=2,
+     val_frequency=3,
+     use_ddp=False,
+     grad_accumulation_steps=2,
+     log_frequency=1,
+     use_compilation=False,
+     embedding_output_range=(-1.0, 1.0),
+     reduce_clip_embedding_dim=True,
+     transformer_embedding_dim=320,
+     normalize_clip_embeddings=True
+ )
+
+ train_losses, best_val_loss = train()
+ """
+