TorchDiff 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. ddim/__init__.py +0 -0
  2. ddim/forward_ddim.py +79 -0
  3. ddim/hyper_param.py +225 -0
  4. ddim/noise_predictor.py +521 -0
  5. ddim/reverse_ddim.py +91 -0
  6. ddim/sample_ddim.py +219 -0
  7. ddim/text_encoder.py +152 -0
  8. ddim/train_ddim.py +394 -0
  9. ddpm/__init__.py +0 -0
  10. ddpm/forward_ddpm.py +89 -0
  11. ddpm/hyper_param.py +180 -0
  12. ddpm/noise_predictor.py +521 -0
  13. ddpm/reverse_ddpm.py +102 -0
  14. ddpm/sample_ddpm.py +213 -0
  15. ddpm/text_encoder.py +152 -0
  16. ddpm/train_ddpm.py +386 -0
  17. ldm/__init__.py +0 -0
  18. ldm/autoencoder.py +855 -0
  19. ldm/forward_idm.py +100 -0
  20. ldm/hyper_param.py +239 -0
  21. ldm/metrics.py +206 -0
  22. ldm/noise_predictor.py +1074 -0
  23. ldm/reverse_ldm.py +119 -0
  24. ldm/sample_ldm.py +254 -0
  25. ldm/text_encoder.py +429 -0
  26. ldm/train_autoencoder.py +216 -0
  27. ldm/train_ldm.py +412 -0
  28. sde/__init__.py +0 -0
  29. sde/forward_sde.py +98 -0
  30. sde/hyper_param.py +200 -0
  31. sde/noise_predictor.py +521 -0
  32. sde/reverse_sde.py +115 -0
  33. sde/sample_sde.py +216 -0
  34. sde/text_encoder.py +152 -0
  35. sde/train_sde.py +400 -0
  36. torchdiff/__init__.py +8 -0
  37. torchdiff/ddim.py +1222 -0
  38. torchdiff/ddpm.py +1153 -0
  39. torchdiff/ldm.py +2156 -0
  40. torchdiff/sde.py +1231 -0
  41. torchdiff/tests/__init__.py +0 -0
  42. torchdiff/tests/test_ddim.py +551 -0
  43. torchdiff/tests/test_ddpm.py +1188 -0
  44. torchdiff/tests/test_ldm.py +742 -0
  45. torchdiff/tests/test_sde.py +626 -0
  46. torchdiff/tests/test_unclip.py +366 -0
  47. torchdiff/unclip.py +4170 -0
  48. torchdiff/utils.py +1660 -0
  49. torchdiff-2.0.0.dist-info/METADATA +315 -0
  50. torchdiff-2.0.0.dist-info/RECORD +68 -0
  51. torchdiff-2.0.0.dist-info/WHEEL +5 -0
  52. torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
  53. torchdiff-2.0.0.dist-info/top_level.txt +6 -0
  54. unclip/__init__.py +0 -0
  55. unclip/clip_model.py +304 -0
  56. unclip/ddim_model.py +1296 -0
  57. unclip/decoder_model.py +312 -0
  58. unclip/prior_diff.py +402 -0
  59. unclip/prior_model.py +264 -0
  60. unclip/project_decoder.py +57 -0
  61. unclip/project_prior.py +170 -0
  62. unclip/train_decoder.py +1059 -0
  63. unclip/train_prior.py +757 -0
  64. unclip/unclip_sampler.py +626 -0
  65. unclip/upsampler.py +432 -0
  66. unclip/upsampler_trainer.py +784 -0
  67. unclip/utils.py +1793 -0
  68. unclip/val_metrics.py +221 -0
@@ -0,0 +1,784 @@
1
+
2
+ import torch.nn.functional as F
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional, Tuple, Union, Callable, List
7
+ from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from torch.distributed import init_process_group, destroy_process_group
11
+ from tqdm import tqdm
12
+ import os
13
+ import warnings
14
+
15
+
16
+
17
+
18
class TrainUpsamplerUnCLIP(nn.Module):
    """Trainer for the UnCLIP upsampler model.

    Orchestrates the training of the UnCLIP upsampler model, integrating forward diffusion,
    noise prediction, and low-resolution image conditioning with optional corruption (Gaussian
    blur or BSR degradation). Supports mixed precision, gradient accumulation, DDP, and
    comprehensive training utilities.

    The upsampler model is expected to expose ``forward_diffusion.variance_scheduler``
    (with ``num_steps`` and ``trainable_beta``), ``low_res_size``, ``model_channels`` and
    ``num_res_blocks``; these attributes are always read from the unwrapped model, so they
    keep working after ``torch.compile`` and/or DDP wrapping.

    Parameters
    ----------
    `upsampler_model` : nn.Module
        The UnCLIP upsampler model (e.g., UpsamplerUnCLIP) to be trained.
    `train_loader` : torch.utils.data.DataLoader
        DataLoader for training data, providing low- and high-resolution image pairs.
    `optimizer` : torch.optim.Optimizer
        Optimizer for training the upsampler model.
    `objective` : Callable
        Loss function to compute the difference between predicted and target noise.
    `val_loader` : torch.utils.data.DataLoader, optional
        DataLoader for validation data, default None.
    `max_epochs` : int, optional
        Maximum number of training epochs (default: 1000).
    `device` : Union[str, torch.device], optional
        Device for computation (default: CUDA if available, else CPU). Overridden by the
        local-rank GPU when DDP is enabled.
    `store_path` : str, optional
        Directory to save model checkpoints (default: "unclip_upsampler").
    `patience` : int, optional
        Number of epochs to wait for improvement before early stopping (default: 100).
        Also used as the patience of the ReduceLROnPlateau scheduler.
    `warmup_epochs` : int, optional
        Number of epochs for linear learning rate warmup (default: 100). The
        ReduceLROnPlateau scheduler only takes over once warmup has finished.
    `val_frequency` : int, optional
        Frequency (in epochs) for validation (default: 10).
    `use_ddp` : bool, optional
        Whether to use Distributed Data Parallel training (default: False).
    `grad_accumulation_steps` : int, optional
        Number of gradient accumulation steps before optimizer update (default: 1).
        Must be >= 1.
    `log_frequency` : int, optional
        Frequency (in epochs) for printing progress (default: 1).
    `use_compilation` : bool, optional
        Whether to compile the model using torch.compile (default: False).
    `image_output_range` : Tuple[float, float], optional
        Range for clamping output images (default: (-1.0, 1.0)).
    `normalize_image_outputs` : bool, optional
        Whether to normalize inputs/outputs (default: True).
    `use_autocast` : bool, optional
        Whether to use automatic mixed precision training (default: True).

    Raises
    ------
    ValueError
        If `grad_accumulation_steps` is smaller than 1, or if DDP is enabled but a
        required environment variable (RANK, LOCAL_RANK, WORLD_SIZE) is missing.
    RuntimeError
        If DDP is enabled but CUDA is not available.
    """

    def __init__(
        self,
        upsampler_model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        objective: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        store_path: str = "unclip_upsampler",
        patience: int = 100,
        warmup_epochs: int = 100,
        val_frequency: int = 10,
        use_ddp: bool = False,
        grad_accumulation_steps: int = 1,
        log_frequency: int = 1,
        use_compilation: bool = False,
        image_output_range: Tuple[float, float] = (-1.0, 1.0),
        normalize_image_outputs: bool = True,
        use_autocast: bool = True
    ) -> None:
        super().__init__()
        if grad_accumulation_steps < 1:
            raise ValueError(f"grad_accumulation_steps must be >= 1, got {grad_accumulation_steps}")

        # Training configuration
        self.use_ddp = use_ddp
        self.grad_accumulation_steps = grad_accumulation_steps
        self.use_compilation = use_compilation
        self.use_autocast = use_autocast

        # Device initialization
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # Distributed setup may replace self.device with the local-rank GPU, so it
        # has to run before the model is moved anywhere.
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        # The model must be assigned BEFORE compiling/wrapping: both helpers operate
        # on self.upsampler_model and would otherwise find nothing to work on.
        self.upsampler_model = upsampler_model.to(self.device)
        self._compile_models()
        self._wrap_models_for_ddp()
        self.num_timesteps = self._unwrapped_model().forward_diffusion.variance_scheduler.num_steps

        # Training components
        self.optimizer = optimizer
        self.objective = objective
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Training parameters
        self.max_epochs = max_epochs
        self.patience = patience
        self.warmup_epochs = warmup_epochs
        self.val_frequency = val_frequency
        self.log_frequency = log_frequency
        self.image_output_range = image_output_range
        self.normalize_image_outputs = normalize_image_outputs

        # Checkpoint management
        self.store_path = store_path

        # Learning rate scheduling: linear warmup first, then plateau-based decay
        # (see _step_lr_schedulers for how the two are sequenced).
        self.scheduler = ReduceLROnPlateau(
            self.optimizer,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)

    def _unwrapped_model(self) -> nn.Module:
        """Returns the underlying upsampler model without DDP / torch.compile wrappers.

        Attributes such as ``forward_diffusion`` or ``model_channels`` live on the
        original module; the DDP wrapper does not forward them, and a compiled module
        keeps the original under ``_orig_mod``.
        """
        model = self.upsampler_model
        if isinstance(model, DDP):
            model = model.module
        return getattr(model, "_orig_mod", model)

    def forward(self) -> Tuple[List[float], float]:
        """Trains the UnCLIP upsampler model to predict noise for denoising.

        Executes the training loop, optimizing the upsampler model using low- and high-resolution
        image pairs, mixed precision, gradient clipping, and learning rate scheduling. Supports
        validation, early stopping, and checkpointing.

        Returns
        -------
        train_losses : List[float]
            List of mean training losses per epoch.
        best_val_loss : float
            Best validation or training loss achieved on a validation-frequency epoch.
        """
        base_model = self._unwrapped_model()
        variance_scheduler = base_model.forward_diffusion.variance_scheduler

        # Set models to training mode; a non-trainable variance scheduler stays in eval mode.
        self.upsampler_model.train()
        if variance_scheduler.trainable_beta:
            variance_scheduler.train()
        else:
            variance_scheduler.eval()

        # Loss scaling only matters for CUDA float16 autocast. A disabled GradScaler is a
        # pass-through (scale/unscale_/update are no-ops, step() calls optimizer.step()),
        # so a single code path serves every precision mode.
        scaler = torch.GradScaler(enabled=self.use_autocast and self.device.type == "cuda")
        train_losses: List[float] = []
        best_val_loss = float("inf")
        wait = 0
        self.optimizer.zero_grad()

        for epoch in range(self.max_epochs):
            sampler = getattr(self.train_loader, "sampler", None)
            if self.use_ddp and hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            train_losses_epoch: List[float] = []
            has_pending_grads = False

            # Training step loop with gradient accumulation
            for step, (low_res_images, high_res_images) in enumerate(
                tqdm(self.train_loader, disable=not self.master_process)
            ):
                loss = self._compute_loss(low_res_images, high_res_images, use_autocast=self.use_autocast)
                scaler.scale(loss / self.grad_accumulation_steps).backward()
                has_pending_grads = True

                if (step + 1) % self.grad_accumulation_steps == 0:
                    self._optimizer_step(scaler)
                    has_pending_grads = False

                train_losses_epoch.append(loss.item())

            # Apply a trailing partial accumulation window now, instead of letting its
            # gradients leak into the first update of the next epoch.
            if has_pending_grads:
                self._optimizer_step(scaler)

            mean_train_loss = self._compute_mean_loss(train_losses_epoch)
            train_losses.append(mean_train_loss)

            if self.master_process and (epoch + 1) % self.log_frequency == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

            current_loss = mean_train_loss
            is_val_epoch = (epoch + 1) % self.val_frequency == 0

            if self.val_loader is not None and is_val_epoch:
                val_loss = self.validate()
                if self.master_process:
                    print(f" | Val Loss: {val_loss:.4f}")
                    print()
                current_loss = val_loss

            self._step_lr_schedulers(current_loss)

            # Losses are all-reduced, so every rank tracks the same best/wait values;
            # checkpoint writing itself is restricted to the master inside _save_checkpoint.
            should_stop = False
            if current_loss < best_val_loss and is_val_epoch:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                should_stop = wait >= self.patience

            if self.use_ddp:
                # Rank 0's decision is authoritative so that all ranks leave the loop
                # together; a lone rank breaking out would deadlock the others.
                stop_flag = torch.tensor([int(should_stop)], device=self.device)
                dist.broadcast(stop_flag, src=0)
                should_stop = bool(stop_flag.item())

            if should_stop:
                if self.master_process:
                    print("Early stopping triggered")
                self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                break

        if self.use_ddp:
            destroy_process_group()

        return train_losses, best_val_loss

    def _compute_loss(
        self,
        low_res_images: torch.Tensor,
        high_res_images: torch.Tensor,
        use_autocast: bool
    ) -> torch.Tensor:
        """Runs one noise-prediction forward pass and returns the objective value.

        Samples random timesteps and Gaussian noise, diffuses the high-resolution images,
        corrupts the low-resolution conditioning images and evaluates the objective on the
        predicted vs. true noise. The returned loss is NOT divided by
        `grad_accumulation_steps`.

        Parameters
        ----------
        `low_res_images` : torch.Tensor
            Low-resolution conditioning batch (moved to `self.device` here).
        `high_res_images` : torch.Tensor
            High-resolution target batch (moved to `self.device` here).
        `use_autocast` : bool
            Whether to run the forward pass under automatic mixed precision.

        Returns
        -------
        loss : torch.Tensor
            Output of `self.objective(predicted_noise, noise)`.
        """
        base_model = self._unwrapped_model()
        low_res_images = low_res_images.to(self.device, non_blocking=True)
        high_res_images = high_res_images.to(self.device, non_blocking=True)
        amp_device = "cuda" if self.device.type == "cuda" else "cpu"

        with torch.autocast(device_type=amp_device, enabled=use_autocast):
            batch_size = high_res_images.shape[0]
            timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
            noise = torch.randn_like(high_res_images)
            # Keep forward diffusion in full precision to avoid NaNs in variance scheduling.
            # The device type must match the outer context, otherwise (e.g. on CPU) the
            # outer autocast would stay active.
            with torch.autocast(device_type=amp_device, enabled=False):
                high_res_images_noisy = base_model.forward_diffusion(high_res_images, noise, timesteps)
            corruption_type = "gaussian_blur" if base_model.low_res_size == 64 else "bsr_degradation"
            low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
            predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
            return self.objective(predicted_noise, noise)

    def _optimizer_step(self, scaler: "torch.GradScaler") -> None:
        """Unscales, clips and applies the accumulated gradients, then clears them.

        Parameters
        ----------
        `scaler` : torch.GradScaler
            Gradient scaler used for the backward passes (may be disabled).
        """
        scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.upsampler_model.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(self._unwrapped_model().forward_diffusion.parameters(), max_norm=1.0)
        scaler.step(self.optimizer)
        scaler.update()
        self.optimizer.zero_grad()
        if self.device.type == "cuda":
            torch.cuda.empty_cache()  # Clear memory after optimizer step

    def _step_lr_schedulers(self, current_loss: float) -> None:
        """Advances the learning-rate schedule by one epoch.

        During warmup only the LambdaLR warmup scheduler is stepped. Afterwards only
        ReduceLROnPlateau is stepped: LambdaLR sets ``lr = base_lr * factor`` on every
        step, so continuing to step it would undo every plateau reduction. Using the
        scheduler's own ``last_epoch`` keeps this correct after resuming from a checkpoint.

        Parameters
        ----------
        `current_loss` : float
            Loss monitored by the plateau scheduler for this epoch.
        """
        if self.warmup_lr_scheduler.last_epoch < self.warmup_epochs:
            self.warmup_lr_scheduler.step()
        else:
            self.scheduler.step(current_loss)

    def _compute_mean_loss(self, losses: List[float], empty_value: float = 0.0) -> float:
        """Computes mean loss with DDP synchronization if needed.

        In DDP mode the per-rank sums and counts are all-reduced, so every rank takes
        part in the collective (even with no losses) and the result is the mean over
        all batches of all ranks.

        Parameters
        ----------
        `losses` : List[float]
            List of loss values for the current epoch.
        `empty_value` : float, optional
            Value returned when there are no losses at all (default: 0.0).

        Returns
        -------
        mean_loss : float
            Mean loss value, synchronized if using DDP.
        """
        total = float(sum(losses))
        count = len(losses)
        if self.use_ddp:
            stats = torch.tensor([total, float(count)], dtype=torch.float64, device=self.device)
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            total, count = stats[0].item(), int(stats[1].item())
        if count == 0:
            return empty_value
        return total / count

    def _setup_ddp(self) -> None:
        """Sets up Distributed Data Parallel training configuration.

        Initializes the process group, sets up rank information, and configures the CUDA
        device for the current process in DDP mode.

        Raises
        ------
        ValueError
            If RANK, LOCAL_RANK or WORLD_SIZE is missing from the environment.
        RuntimeError
            If CUDA is not available.
        """
        required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
        for var in required_env_vars:
            if var not in os.environ:
                raise ValueError(f"DDP enabled but {var} environment variable not set")

        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA but CUDA is not available")

        if not torch.distributed.is_initialized():
            init_process_group(backend="nccl")

        self.ddp_rank = int(os.environ["RANK"])
        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
        self.ddp_world_size = int(os.environ["WORLD_SIZE"])

        self.device = torch.device(f"cuda:{self.ddp_local_rank}")
        torch.cuda.set_device(self.device)

        self.master_process = self.ddp_rank == 0

        if self.master_process:
            print(f"DDP initialized with world_size={self.ddp_world_size}")

    def _setup_single_gpu(self) -> None:
        """Sets up single GPU or CPU training configuration.

        Configures the training setup for single-device operation, setting rank and process
        information for non-DDP training.
        """
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True

    @staticmethod
    def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
        """Creates a learning rate scheduler for warmup.

        Generates a scheduler that linearly increases the learning rate from 0 to the
        optimizer's initial value over the specified warmup epochs.

        Parameters
        ----------
        `optimizer` : torch.optim.Optimizer
            Optimizer to apply the scheduler to.
        `warmup_epochs` : int
            Number of epochs for the warmup phase.

        Returns
        -------
        lr_scheduler : torch.optim.lr_scheduler.LambdaLR
            Learning rate scheduler for warmup.
        """
        def lr_lambda(epoch):
            return min(1.0, epoch / warmup_epochs) if warmup_epochs > 0 else 1.0

        return LambdaLR(optimizer, lr_lambda)

    def _wrap_models_for_ddp(self) -> None:
        """Wraps models with DistributedDataParallel for multi-GPU training.

        Configures the upsampler model for DDP training by wrapping it with DistributedDataParallel.
        """
        if self.use_ddp:
            self.upsampler_model = self.upsampler_model.to(self.ddp_local_rank)
            self.upsampler_model = DDP(
                self.upsampler_model,
                device_ids=[self.ddp_local_rank],
                find_unused_parameters=True
            )

    def _compile_models(self) -> None:
        """Compiles models for optimization if supported.

        Attempts to compile the upsampler model using torch.compile for optimization,
        falling back to uncompiled execution if compilation fails.
        """
        if self.use_compilation:
            try:
                self.upsampler_model = self.upsampler_model.to(self.device)
                self.upsampler_model = torch.compile(self.upsampler_model, mode="reduce-overhead")

                if self.master_process:
                    print("Models compiled successfully")
            except Exception as e:
                if self.master_process:
                    print(f"Model compilation failed: {e}. Continuing without compilation.")

    def corrupt_conditioning_image(self, x_low: torch.Tensor, corruption_type: str = "gaussian_blur") -> torch.Tensor:
        """Corrupts the low-resolution conditioning image for robustness.

        Applies Gaussian blur or BSR degradation to the low-resolution image to simulate
        real-world degradation, as specified in the UnCLIP paper.

        Parameters
        ----------
        `x_low` : torch.Tensor
            Low-resolution input image, shape (batch_size, channels, low_res_size, low_res_size).
        `corruption_type` : str, optional
            Type of corruption to apply: "gaussian_blur" or "bsr_degradation" (default: "gaussian_blur").
            Any other value returns the input unchanged.

        Returns
        -------
        x_degraded : torch.Tensor
            Corrupted low-resolution image, same shape as input.
        """
        if corruption_type == "gaussian_blur":
            # apply Gaussian blur with a random odd kernel size and sigma
            kernel_size = random.choice([3, 5, 7])
            sigma = random.uniform(0.5, 2.0)
            return self._gaussian_blur(x_low, kernel_size, sigma)
        elif corruption_type == "bsr_degradation":
            # more diverse BSR degradation for second upsampler
            return self._bsr_degradation(x_low)
        else:
            return x_low

    def _gaussian_blur(self, x: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Applies Gaussian blur to the input image.

        Parameters
        ----------
        `x` : torch.Tensor
            Input image tensor, shape (batch_size, channels, height, width).
        `kernel_size` : int
            Size of the Gaussian kernel (odd sizes keep the spatial size unchanged).
        `sigma` : float
            Standard deviation of the Gaussian distribution.

        Returns
        -------
        x_blurred : torch.Tensor
            Blurred image tensor, same shape and dtype as input.
        """
        # create Gaussian kernel; match the input dtype so non-fp32 inputs work
        kernel = self._get_gaussian_kernel(kernel_size, sigma).to(device=x.device, dtype=x.dtype)
        # one depthwise filter per channel
        kernel = kernel.expand(x.shape[1], 1, kernel_size, kernel_size)
        padding = kernel_size // 2
        return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])

    def _get_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
        """Generates a 2D Gaussian kernel.

        Parameters
        ----------
        `kernel_size` : int
            Size of the Gaussian kernel.
        `sigma` : float
            Standard deviation of the Gaussian distribution.

        Returns
        -------
        kernel : torch.Tensor
            Normalized 2D Gaussian kernel (float32), shape (kernel_size, kernel_size).
        """
        coords = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g = g / g.sum()
        return g[:, None] * g[None, :]

    def _bsr_degradation(self, x: torch.Tensor) -> torch.Tensor:
        """Applies BSR degradation to the input image.

        Simulates degradation with noise and Gaussian blur, as used in the UnCLIP paper
        for the second upsampler.

        Parameters
        ----------
        `x` : torch.Tensor
            Input image tensor, shape (batch_size, channels, height, width).

        Returns
        -------
        x_degraded : torch.Tensor
            Degraded image tensor, same shape as input, clamped to [-1, 1].
        """
        # add noise
        noise_level = random.uniform(0.0, 0.1)
        noise = torch.randn_like(x) * noise_level

        # apply blur
        kernel_size = random.choice([3, 5, 7])
        sigma = random.uniform(0.5, 3.0)
        x_degraded = self._gaussian_blur(x + noise, kernel_size, sigma)

        return torch.clamp(x_degraded, -1.0, 1.0)

    def validate(self) -> float:
        """Validates the UnCLIP upsampler model.

        Computes the validation loss by applying forward diffusion to high-resolution images,
        predicting noise with the upsampler model conditioned on corrupted low-resolution images,
        and comparing predicted noise to ground truth. Runs without autocast and without
        gradient tracking, then restores training mode.

        Returns
        -------
        val_loss : float
            Mean validation loss (NaN if the validation loader yields no batches).
        """
        base_model = self._unwrapped_model()
        variance_scheduler = base_model.forward_diffusion.variance_scheduler

        # set models to eval mode for evaluation
        self.upsampler_model.eval()
        base_model.forward_diffusion.eval()

        val_losses: List[float] = []
        with torch.no_grad():
            for low_res_images, high_res_images in self.val_loader:
                loss = self._compute_loss(low_res_images, high_res_images, use_autocast=False)
                val_losses.append(loss.item())

        # average over all batches (and all ranks under DDP)
        val_loss = self._compute_mean_loss(val_losses, empty_value=float("nan"))

        # return to training mode
        self.upsampler_model.train()
        if not variance_scheduler.trainable_beta:
            variance_scheduler.eval()

        return val_loss

    def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
        """Saves model checkpoint.

        Saves the state of the upsampler model, its variance scheduler, optimizer, and
        schedulers, with options for best model and epoch-specific checkpoints. Only the
        master process writes; state dicts are taken from the unwrapped model so that keys
        carry no DDP / torch.compile prefixes.

        Parameters
        ----------
        `epoch` : int
            Current epoch number.
        `loss` : float
            Current loss value.
        `is_best` : bool, optional
            Whether to save as the best model checkpoint (default: False).
        `suffix` : str, optional
            Suffix to add to checkpoint filename, default "".
        """
        if not self.master_process:
            return
        base_model = self._unwrapped_model()
        checkpoint = {
            'epoch': epoch,
            'loss': loss,
            # Core model
            'upsampler_model_state_dict': base_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Training configuration
            'model_channels': base_model.model_channels,
            'num_res_blocks': base_model.num_res_blocks,
            'normalize': self.normalize_image_outputs,
            'output_range': self.image_output_range
        }

        # Save variance scheduler (submodule of forward_diffusion)
        checkpoint['variance_scheduler_state_dict'] = base_model.forward_diffusion.variance_scheduler.state_dict()

        # Save schedulers state
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()

        filename = f"unclip_upsampler_epoch_{epoch}{suffix}.pth"
        if is_best:
            filename = f"unclip_upsampler_best{suffix}.pth"

        filepath = os.path.join(self.store_path, filename)
        os.makedirs(self.store_path, exist_ok=True)
        torch.save(checkpoint, filepath)

        if is_best:
            print(f"Best model saved: {filepath}")

    @staticmethod
    def _strip_wrapper_prefixes(state_dict: dict) -> dict:
        """Removes leading DDP (``module.``) / torch.compile (``_orig_mod.``) key prefixes.

        Only prefixes at the START of a key are removed (repeatedly, for nested
        wrappers); occurrences inside a key are left intact. If no key carries a
        prefix, the original mapping is returned unchanged (keeping its metadata).

        Parameters
        ----------
        `state_dict` : dict
            State dict possibly saved from a wrapped model.

        Returns
        -------
        state_dict : dict
            State dict whose keys match the unwrapped model.
        """
        prefixes = ("module.", "_orig_mod.")
        if not any(key.startswith(prefixes) for key in state_dict):
            return state_dict
        cleaned = {}
        for key, value in state_dict.items():
            while key.startswith(prefixes):
                for prefix in prefixes:
                    if key.startswith(prefix):
                        key = key[len(prefix):]
                        break
            cleaned[key] = value
        return cleaned

    def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
        """Loads model checkpoint.

        Restores the state of the upsampler model, its variance scheduler, optimizer, and
        schedulers from a saved checkpoint. Weights are loaded into the unwrapped model,
        so checkpoints saved with or without DDP / torch.compile wrappers are accepted.
        Individual component failures are reported as warnings.

        Parameters
        ----------
        `checkpoint_path` : str
            Path to the checkpoint file.

        Returns
        -------
        epoch : int
            The epoch at which the checkpoint was saved.
        loss : float
            The loss at the checkpoint.

        Raises
        ------
        FileNotFoundError
            If `checkpoint_path` does not exist.
        """
        try:
            # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
        except FileNotFoundError as err:
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from err

        base_model = self._unwrapped_model()

        def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str) -> None:
            """Loads a state dict into an unwrapped module, warning on failure."""
            try:
                model.load_state_dict(self._strip_wrapper_prefixes(state_dict))
                if self.master_process:
                    print(f"✓ Loaded {model_name}")
            except Exception as e:
                warnings.warn(f"Failed to load {model_name}: {e}")

        # Load core upsampler model
        if 'upsampler_model_state_dict' in checkpoint:
            _load_model_state_dict(base_model, checkpoint['upsampler_model_state_dict'],
                                   'upsampler_model')

        # Load variance scheduler (submodule of forward_diffusion); older checkpoints
        # store it under 'hyper_params_state_dict'.
        if 'variance_scheduler_state_dict' in checkpoint or 'hyper_params_state_dict' in checkpoint:
            state_dict = checkpoint.get('variance_scheduler_state_dict', checkpoint.get('hyper_params_state_dict'))
            _load_model_state_dict(base_model.forward_diffusion.variance_scheduler, state_dict, 'variance_scheduler')

        # Load optimizer
        if 'optimizer_state_dict' in checkpoint:
            try:
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if self.master_process:
                    print("✓ Loaded optimizer")
            except Exception as e:
                warnings.warn(f"Failed to load optimizer state: {e}")

        # Load schedulers
        if 'scheduler_state_dict' in checkpoint:
            try:
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded main scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load scheduler state: {e}")

        if 'warmup_scheduler_state_dict' in checkpoint:
            try:
                self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
                if self.master_process:
                    print("✓ Loaded warmup scheduler")
            except Exception as e:
                warnings.warn(f"Failed to load warmup scheduler state: {e}")

        # Verify configuration compatibility
        if 'model_channels' in checkpoint:
            if checkpoint['model_channels'] != base_model.model_channels:
                warnings.warn(
                    f"Model channels mismatch: checkpoint={checkpoint['model_channels']}, current={base_model.model_channels}")

        if 'num_res_blocks' in checkpoint:
            if checkpoint['num_res_blocks'] != base_model.num_res_blocks:
                warnings.warn(
                    f"Num res blocks mismatch: checkpoint={checkpoint['num_res_blocks']}, current={base_model.num_res_blocks}")

        if 'normalize' in checkpoint:
            if checkpoint['normalize'] != self.normalize_image_outputs:
                warnings.warn(
                    f"Normalize setting mismatch: checkpoint={checkpoint['normalize']}, current={self.normalize_image_outputs}")

        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint.get('loss', float('inf'))

        if self.master_process:
            print(f"Successfully loaded checkpoint from {checkpoint_path}")
            print(f"Epoch: {epoch}, Loss: {loss:.4f}")

        return epoch, loss
688
+
689
+
690
+ """
691
+ from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP
692
+ from upsampler import UpsamplerUnCLIP
693
+ import torch
694
+ import torch.optim as optim
695
+ import torch.nn as nn
696
+ from torch.utils.data import Dataset, DataLoader
697
+
698
+ # Define a dummy dataset for example purposes (replace with real dataset in practice)
699
+ class DummyDataset(Dataset):
700
+ def __init__(self, num_samples=1000, low_res_size=64, high_res_size=256):
701
+ self.num_samples = num_samples
702
+ self.low_res_size = low_res_size
703
+ self.high_res_size = high_res_size
704
+
705
+ def __len__(self):
706
+ return self.num_samples
707
+
708
+ def __getitem__(self, idx):
709
+ # Generate random low-res and high-res images (in practice, load from disk or augment)
710
+ low_res_image = torch.rand(3, self.low_res_size, self.low_res_size) * 2 - 1 # Normalize to [-1, 1]
711
+ high_res_image = torch.rand(3, self.high_res_size, self.high_res_size) * 2 - 1 # Normalize to [-1, 1]
712
+ return low_res_image, high_res_image
713
+
714
+ # Instantiate the variance scheduler
715
+ hyp = VarianceSchedulerUnCLIP(
716
+ num_steps=400,
717
+ beta_start=1e-4,
718
+ beta_end=0.02,
719
+ trainable_beta=True,
720
+ beta_method="linear"
721
+ )
722
+
723
+ # Instantiate the forward diffusion process
724
+ forward = ForwardUnCLIP(hyp)
725
+
726
+ # Instantiate the upsampler model
727
+ model = UpsamplerUnCLIP(
728
+ forward_diffusion=forward,
729
+ in_channels=3,
730
+ out_channels=3,
731
+ model_channels=32,
732
+ num_res_blocks=2,
733
+ channel_mult=(1, 2, 4, 8),
734
+ dropout=0.1,
735
+ time_embed_dim=32,
736
+ low_res_size=64,
737
+ high_res_size=256
738
+ )
739
+
740
+ # Create train loader with dummy dataset (replace with real DataLoader for your dataset)
741
+ train_dataset = DummyDataset(num_samples=4)
742
+ train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
743
+
744
+ # Optional validation loader (using same dummy for example)
745
+ val_dataset = DummyDataset(num_samples=2)
746
+ val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)
747
+
748
+ # Define optimizer
749
+ optimizer = optim.AdamW(model.parameters(), lr=1e-3)
750
+
751
+ # Define objective (loss function, e.g., MSE for noise prediction)
752
+ objective = nn.MSELoss()
753
+
754
+ # Instantiate the trainer
755
+ trainer = TrainUpsamplerUnCLIP(
756
+ upsampler_model=model,
757
+ train_loader=train_loader,
758
+ optimizer=optimizer,
759
+ objective=objective,
760
+ val_loader=val_loader, # Optional
761
+ max_epochs=10, # Small number for example; increase for real training
762
+ device='cuda' if torch.cuda.is_available() else 'cpu',
763
+ store_path="upsampler",
764
+ patience=10,
765
+ warmup_epochs=2,
766
+ val_frequency=5,
767
+ use_ddp=False, # Set to True if using distributed training
768
+ grad_accumulation_steps=2,
769
+ log_frequency=1,
770
+ use_compilation=False, # Set to True if torch.compile is desired and supported
771
+ image_output_range=(-1.0, 1.0),
772
+ normalize_image_outputs=True,
773
+ use_autocast=False
774
+ )
775
+
776
+ # Run the training
777
+ train_losses, best_val_loss = trainer()
778
+
779
+ # Print results
780
+ print(f"Training losses: {train_losses}")
781
+ print(f"Best validation loss: {best_val_loss}")
782
+
783
+ """
784
+