TorchDiff 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ddim/__init__.py +0 -0
- ddim/forward_ddim.py +79 -0
- ddim/hyper_param.py +225 -0
- ddim/noise_predictor.py +521 -0
- ddim/reverse_ddim.py +91 -0
- ddim/sample_ddim.py +219 -0
- ddim/text_encoder.py +152 -0
- ddim/train_ddim.py +394 -0
- ddpm/__init__.py +0 -0
- ddpm/forward_ddpm.py +89 -0
- ddpm/hyper_param.py +180 -0
- ddpm/noise_predictor.py +521 -0
- ddpm/reverse_ddpm.py +102 -0
- ddpm/sample_ddpm.py +213 -0
- ddpm/text_encoder.py +152 -0
- ddpm/train_ddpm.py +386 -0
- ldm/__init__.py +0 -0
- ldm/autoencoder.py +855 -0
- ldm/forward_idm.py +100 -0
- ldm/hyper_param.py +239 -0
- ldm/metrics.py +206 -0
- ldm/noise_predictor.py +1074 -0
- ldm/reverse_ldm.py +119 -0
- ldm/sample_ldm.py +254 -0
- ldm/text_encoder.py +429 -0
- ldm/train_autoencoder.py +216 -0
- ldm/train_ldm.py +412 -0
- sde/__init__.py +0 -0
- sde/forward_sde.py +98 -0
- sde/hyper_param.py +200 -0
- sde/noise_predictor.py +521 -0
- sde/reverse_sde.py +115 -0
- sde/sample_sde.py +216 -0
- sde/text_encoder.py +152 -0
- sde/train_sde.py +400 -0
- torchdiff/__init__.py +8 -0
- torchdiff/ddim.py +1222 -0
- torchdiff/ddpm.py +1153 -0
- torchdiff/ldm.py +2156 -0
- torchdiff/sde.py +1231 -0
- torchdiff/tests/__init__.py +0 -0
- torchdiff/tests/test_ddim.py +551 -0
- torchdiff/tests/test_ddpm.py +1188 -0
- torchdiff/tests/test_ldm.py +742 -0
- torchdiff/tests/test_sde.py +626 -0
- torchdiff/tests/test_unclip.py +366 -0
- torchdiff/unclip.py +4170 -0
- torchdiff/utils.py +1660 -0
- torchdiff-2.0.0.dist-info/METADATA +315 -0
- torchdiff-2.0.0.dist-info/RECORD +68 -0
- torchdiff-2.0.0.dist-info/WHEEL +5 -0
- torchdiff-2.0.0.dist-info/licenses/LICENSE +21 -0
- torchdiff-2.0.0.dist-info/top_level.txt +6 -0
- unclip/__init__.py +0 -0
- unclip/clip_model.py +304 -0
- unclip/ddim_model.py +1296 -0
- unclip/decoder_model.py +312 -0
- unclip/prior_diff.py +402 -0
- unclip/prior_model.py +264 -0
- unclip/project_decoder.py +57 -0
- unclip/project_prior.py +170 -0
- unclip/train_decoder.py +1059 -0
- unclip/train_prior.py +757 -0
- unclip/unclip_sampler.py +626 -0
- unclip/upsampler.py +432 -0
- unclip/upsampler_trainer.py +784 -0
- unclip/utils.py +1793 -0
- unclip/val_metrics.py +221 -0
|
@@ -0,0 +1,784 @@
|
|
|
1
|
+
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
import random
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from typing import Optional, Tuple, Union, Callable, List
|
|
7
|
+
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
|
|
8
|
+
import torch.distributed as dist
|
|
9
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
10
|
+
from torch.distributed import init_process_group, destroy_process_group
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
import os
|
|
13
|
+
import warnings
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TrainUpsamplerUnCLIP(nn.Module):
|
|
19
|
+
"""Trainer for the UnCLIP upsampler model.
|
|
20
|
+
|
|
21
|
+
Orchestrates the training of the UnCLIP upsampler model, integrating forward diffusion,
|
|
22
|
+
noise prediction, and low-resolution image conditioning with optional corruption (Gaussian
|
|
23
|
+
blur or BSR degradation). Supports mixed precision, gradient accumulation, DDP, and
|
|
24
|
+
comprehensive training utilities.
|
|
25
|
+
|
|
26
|
+
Parameters
|
|
27
|
+
----------
|
|
28
|
+
`upsampler_model` : nn.Module
|
|
29
|
+
The UnCLIP upsampler model (e.g., UpsamplerUnCLIP) to be trained.
|
|
30
|
+
`train_loader` : torch.utils.data.DataLoader
|
|
31
|
+
DataLoader for training data, providing low- and high-resolution image pairs.
|
|
32
|
+
`optimizer` : torch.optim.Optimizer
|
|
33
|
+
Optimizer for training the upsampler model.
|
|
34
|
+
`objective` : Callable
|
|
35
|
+
Loss function to compute the difference between predicted and target noise.
|
|
36
|
+
`val_loader` : torch.utils.data.DataLoader, optional
|
|
37
|
+
DataLoader for validation data, default None.
|
|
38
|
+
`max_epochs` : int, optional
|
|
39
|
+
Maximum number of training epochs (default: 1000).
|
|
40
|
+
`device` : Union[str, torch.device], optional
|
|
41
|
+
Device for computation (default: CUDA if available, else CPU).
|
|
42
|
+
`store_path` : str, optional
|
|
43
|
+
Directory to save model checkpoints (default: "unclip_upsampler").
|
|
44
|
+
`patience` : int, optional
|
|
45
|
+
Number of epochs to wait for improvement before early stopping (default: 100).
|
|
46
|
+
`warmup_epochs` : int, optional
|
|
47
|
+
Number of epochs for learning rate warmup (default: 100).
|
|
48
|
+
`val_frequency` : int, optional
|
|
49
|
+
Frequency (in epochs) for validation (default: 10).
|
|
50
|
+
`use_ddp` : bool, optional
|
|
51
|
+
Whether to use Distributed Data Parallel training (default: False).
|
|
52
|
+
`grad_accumulation_steps` : int, optional
|
|
53
|
+
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
54
|
+
`log_frequency` : int, optional
|
|
55
|
+
Frequency (in epochs) for printing progress (default: 1).
|
|
56
|
+
`use_compilation` : bool, optional
|
|
57
|
+
Whether to compile the model using torch.compile (default: False).
|
|
58
|
+
`image_output_range` : Tuple[float, float], optional
|
|
59
|
+
Range for clamping output images (default: (-1.0, 1.0)).
|
|
60
|
+
`normalize_image_outputs` : bool, optional
|
|
61
|
+
Whether to normalize inputs/outputs (default: True).
|
|
62
|
+
`use_autocast` : bool, optional
|
|
63
|
+
Whether to use automatic mixed precision training (default: True).
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
def __init__(
    self,
    upsampler_model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    objective: Callable,
    val_loader: Optional[torch.utils.data.DataLoader] = None,
    max_epochs: int = 1000,
    device: Optional[Union[str, torch.device]] = None,
    store_path: str = "unclip_upsampler",
    patience: int = 100,
    warmup_epochs: int = 100,
    val_frequency: int = 10,
    use_ddp: bool = False,
    grad_accumulation_steps: int = 1,
    log_frequency: int = 1,
    use_compilation: bool = False,
    image_output_range: Tuple[float, float] = (-1.0, 1.0),
    normalize_image_outputs: bool = True,
    use_autocast: bool = True
) -> None:
    """Configure devices, distributed state, the model and LR schedulers.

    The upsampler model is registered on ``self`` and moved to the target
    device *before* it is (optionally) compiled and wrapped in
    DistributedDataParallel, because ``_compile_models`` and
    ``_wrap_models_for_ddp`` both read and replace ``self.upsampler_model``.

    Parameters
    ----------
    `upsampler_model` : nn.Module
        Model to train; must expose ``forward_diffusion.variance_scheduler.num_steps``.
    `train_loader`, `val_loader` : torch.utils.data.DataLoader
        Loaders yielding ``(low_res_images, high_res_images)`` pairs
        (``val_loader`` may be None).
    `optimizer` : torch.optim.Optimizer
        Optimizer over the model parameters.
    `objective` : Callable
        Loss function ``objective(predicted_noise, noise)``.
    `max_epochs`, `patience`, `warmup_epochs`, `val_frequency`, `log_frequency` : int
        Loop length, early-stopping / plateau patience, warmup length and
        validation / logging periods (in epochs).
    `device` : str or torch.device, optional
        Compute device; defaults to CUDA when available. Overridden to
        ``cuda:LOCAL_RANK`` when ``use_ddp`` is set.
    `store_path` : str
        Directory for checkpoints.
    `use_ddp`, `use_compilation`, `use_autocast` : bool
        Enable DDP, ``torch.compile`` and mixed precision respectively.
    `grad_accumulation_steps` : int
        Number of micro-batches per optimizer update (must be >= 1).
    `image_output_range` : Tuple[float, float]
        Stored in checkpoints as ``output_range``.
    `normalize_image_outputs` : bool
        Stored in checkpoints as ``normalize``.

    Raises
    ------
    ValueError
        If ``grad_accumulation_steps`` is smaller than 1.
    """
    super().__init__()
    if grad_accumulation_steps < 1:
        # Used as a modulus / divisor in the training loop.
        raise ValueError("grad_accumulation_steps must be >= 1")

    # Training configuration
    self.use_ddp = use_ddp
    self.grad_accumulation_steps = grad_accumulation_steps
    self.use_compilation = use_compilation
    self.use_autocast = use_autocast

    # Device initialization (_setup_ddp may replace it with cuda:LOCAL_RANK)
    if device is None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        self.device = torch.device(device)
    else:
        self.device = device

    # Setup distributed / single-device bookkeeping (sets master_process, ranks)
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()

    # Read model configuration from the bare module: a DDP wrapper does not
    # forward attribute access to the wrapped model.
    self.num_timesteps = upsampler_model.forward_diffusion.variance_scheduler.num_steps

    # Core model: it must be assigned before compiling/wrapping, since those
    # helpers operate on (and replace) self.upsampler_model.
    self.upsampler_model = upsampler_model.to(self.device)
    self._compile_models()
    self._wrap_models_for_ddp()

    # Training components
    self.optimizer = optimizer
    self.objective = objective
    self.train_loader = train_loader
    self.val_loader = val_loader

    # Training parameters
    self.max_epochs = max_epochs
    self.patience = patience
    self.val_frequency = val_frequency
    self.log_frequency = log_frequency
    self.image_output_range = image_output_range
    self.normalize_image_outputs = normalize_image_outputs

    # Checkpoint management
    self.store_path = store_path

    # Learning rate scheduling
    self.scheduler = ReduceLROnPlateau(
        self.optimizer,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optimizer, warmup_epochs)
|
|
140
|
+
|
|
141
|
+
def forward(self) -> Tuple[List[float], float]:
    """Train the UnCLIP upsampler to predict the noise added by forward diffusion.

    Each step draws random timesteps and Gaussian noise, noises the
    high-resolution images (forward diffusion is always run with autocast
    disabled), corrupts the low-resolution conditioning images and regresses
    the model's noise prediction onto the true noise with ``self.objective``.
    Gradients are accumulated over ``grad_accumulation_steps`` micro-batches,
    clipped to norm 1.0 and applied (through a GradScaler when autocast is
    on). Any partially accumulated gradients are applied at the end of the
    epoch so they do not leak into the next one.

    Returns
    -------
    train_losses : List[float]
        Mean training loss per epoch.
    best_val_loss : float
        Best validation (or training, when no validation loader is given)
        loss seen on an evaluation epoch; ``inf`` on non-master ranks.
    """
    # DDP does not proxy attribute access, so reach submodules and config
    # through the wrapped module when distributed training is enabled.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model
    variance_scheduler = base_model.forward_diffusion.variance_scheduler
    # Corruption strategy depends only on the (fixed) model configuration.
    corruption_type = "gaussian_blur" if base_model.low_res_size == 64 else "bsr_degradation"
    # Autocast must be toggled for the device actually in use; a hard-coded
    # "cuda" context would leave CPU autocast active inside forward diffusion.
    amp_device_type = 'cuda' if self.device.type == 'cuda' else 'cpu'

    # Set models to training mode
    self.upsampler_model.train()
    if variance_scheduler.trainable_beta:
        variance_scheduler.train()
    else:
        variance_scheduler.eval()

    # Initialize training components
    scaler = torch.GradScaler() if self.use_autocast else None  # Only use scaler with autocast
    train_losses = []
    best_val_loss = float("inf")
    wait = 0

    def _optimizer_step() -> None:
        """Clip accumulated gradients, apply one optimizer update and reset them."""
        if scaler is not None:
            scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.upsampler_model.parameters(), max_norm=1.0)
        torch.nn.utils.clip_grad_norm_(base_model.forward_diffusion.parameters(), max_norm=1.0)
        if scaler is not None:
            scaler.step(self.optimizer)
            scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()
        torch.cuda.empty_cache()  # Clear memory after optimizer step

    # Main training loop
    for epoch in range(self.max_epochs):
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        has_pending_grads = False

        for step, (low_res_images, high_res_images) in enumerate(tqdm(self.train_loader, disable=not self.master_process)):
            low_res_images = low_res_images.to(self.device, non_blocking=True)
            high_res_images = high_res_images.to(self.device, non_blocking=True)

            # Forward pass; with enabled=False this behaves like no autocast.
            with torch.autocast(device_type=amp_device_type, enabled=self.use_autocast):
                batch_size = high_res_images.shape[0]
                timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
                noise = torch.randn_like(high_res_images)
                # Force FP32 for forward_diffusion to avoid NaN in variance scheduling
                with torch.autocast(device_type=amp_device_type, enabled=False):
                    high_res_images_noisy = base_model.forward_diffusion(high_res_images, noise, timesteps)
                low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
                predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
                loss = self.objective(predicted_noise, noise) / self.grad_accumulation_steps

            # Backward pass
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            has_pending_grads = True

            if (step + 1) % self.grad_accumulation_steps == 0:
                _optimizer_step()
                has_pending_grads = False

            train_losses_epoch.append(loss.item() * self.grad_accumulation_steps)

        # Apply a trailing partial accumulation window instead of carrying it over.
        if has_pending_grads:
            _optimizer_step()

        # Warmup is scheduled per epoch, after the epoch's optimizer steps.
        self.warmup_lr_scheduler.step()

        mean_train_loss = self._compute_mean_loss(train_losses_epoch)
        train_losses.append(mean_train_loss)

        if self.master_process and (epoch + 1) % self.log_frequency == 0:
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        current_loss = mean_train_loss
        is_eval_epoch = (epoch + 1) % self.val_frequency == 0

        if self.val_loader is not None and is_eval_epoch:
            val_loss = self.validate()
            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}")
                print()
            current_loss = val_loss

        self.scheduler.step(current_loss)

        stop_training = False
        if self.master_process:
            if current_loss < best_val_loss and is_eval_epoch:
                best_val_loss = current_loss
                wait = 0
                self._save_checkpoint(epoch + 1, best_val_loss, is_best=True)
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, current_loss, suffix="_early_stop")
                    stop_training = True

        if self.use_ddp:
            # Only rank 0 tracks patience; share its decision so every rank
            # leaves the loop together instead of blocking in a collective.
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    if self.use_ddp:
        destroy_process_group()

    return train_losses, best_val_loss
|
|
265
|
+
|
|
266
|
+
def _compute_mean_loss(self, losses: List[float]) -> float:
|
|
267
|
+
"""Computes mean loss with DDP synchronization if needed.
|
|
268
|
+
|
|
269
|
+
Calculates the mean of the provided losses and synchronizes the result across
|
|
270
|
+
processes in DDP mode.
|
|
271
|
+
|
|
272
|
+
Parameters
|
|
273
|
+
----------
|
|
274
|
+
`losses` : List[float]
|
|
275
|
+
List of loss values for the current epoch.
|
|
276
|
+
|
|
277
|
+
Returns
|
|
278
|
+
-------
|
|
279
|
+
mean_loss : float
|
|
280
|
+
Mean loss value, synchronized if using DDP.
|
|
281
|
+
"""
|
|
282
|
+
if not losses:
|
|
283
|
+
return 0.0
|
|
284
|
+
mean_loss = sum(losses) / len(losses)
|
|
285
|
+
if self.use_ddp:
|
|
286
|
+
# synchronize loss across all processes
|
|
287
|
+
loss_tensor = torch.tensor(mean_loss, device=self.device)
|
|
288
|
+
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
|
|
289
|
+
mean_loss = (loss_tensor / self.ddp_world_size).item()
|
|
290
|
+
|
|
291
|
+
return mean_loss
|
|
292
|
+
|
|
293
|
+
def _setup_ddp(self) -> None:
|
|
294
|
+
"""Sets up Distributed Data Parallel training configuration.
|
|
295
|
+
|
|
296
|
+
Initializes the process group, sets up rank information, and configures the CUDA
|
|
297
|
+
device for the current process in DDP mode.
|
|
298
|
+
"""
|
|
299
|
+
required_env_vars = ["RANK", "LOCAL_RANK", "WORLD_SIZE"]
|
|
300
|
+
for var in required_env_vars:
|
|
301
|
+
if var not in os.environ:
|
|
302
|
+
raise ValueError(f"DDP enabled but {var} environment variable not set")
|
|
303
|
+
|
|
304
|
+
if not torch.cuda.is_available():
|
|
305
|
+
raise RuntimeError("DDP requires CUDA but CUDA is not available")
|
|
306
|
+
|
|
307
|
+
if not torch.distributed.is_initialized():
|
|
308
|
+
init_process_group(backend="nccl")
|
|
309
|
+
|
|
310
|
+
self.ddp_rank = int(os.environ["RANK"])
|
|
311
|
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
|
312
|
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"])
|
|
313
|
+
|
|
314
|
+
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
|
|
315
|
+
torch.cuda.set_device(self.device)
|
|
316
|
+
|
|
317
|
+
self.master_process = self.ddp_rank == 0
|
|
318
|
+
|
|
319
|
+
if self.master_process:
|
|
320
|
+
print(f"DDP initialized with world_size={self.ddp_world_size}")
|
|
321
|
+
|
|
322
|
+
def _setup_single_gpu(self) -> None:
|
|
323
|
+
"""Sets up single GPU or CPU training configuration.
|
|
324
|
+
|
|
325
|
+
Configures the training setup for single-device operation, setting rank and process
|
|
326
|
+
information for non-DDP training.
|
|
327
|
+
"""
|
|
328
|
+
self.ddp_rank = 0
|
|
329
|
+
self.ddp_local_rank = 0
|
|
330
|
+
self.ddp_world_size = 1
|
|
331
|
+
self.master_process = True
|
|
332
|
+
|
|
333
|
+
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_epochs: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear learning-rate warmup scheduler.

    The LR multiplier is ``min(1, step / warmup_epochs)``, so it starts at 0
    and reaches the optimizer's base LR after ``warmup_epochs`` scheduler
    steps. With ``warmup_epochs <= 0`` the multiplier is constantly 1.

    Parameters
    ----------
    `optimizer` : torch.optim.Optimizer
        Optimizer whose learning rate is scaled.
    `warmup_epochs` : int
        Length of the warmup phase, in scheduler steps (epochs).

    Returns
    -------
    lr_scheduler : torch.optim.lr_scheduler.LambdaLR
        Scheduler applying the warmup multiplier.
    """
    if warmup_epochs > 0:
        def factor(step_idx):
            return min(1.0, step_idx / warmup_epochs)
    else:
        def factor(step_idx):
            return 1.0

    return LambdaLR(optimizer, lr_lambda=factor)
|
|
356
|
+
|
|
357
|
+
def _wrap_models_for_ddp(self) -> None:
|
|
358
|
+
"""Wraps models with DistributedDataParallel for multi-GPU training.
|
|
359
|
+
|
|
360
|
+
Configures the upsampler model for DDP training by wrapping it with DistributedDataParallel.
|
|
361
|
+
"""
|
|
362
|
+
if self.use_ddp:
|
|
363
|
+
self.upsampler_model = self.upsampler_model.to(self.ddp_local_rank)
|
|
364
|
+
self.upsampler_model = DDP(
|
|
365
|
+
self.upsampler_model,
|
|
366
|
+
device_ids=[self.ddp_local_rank],
|
|
367
|
+
find_unused_parameters=True
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
def _compile_models(self) -> None:
|
|
371
|
+
"""Compiles models for optimization if supported.
|
|
372
|
+
|
|
373
|
+
Attempts to compile the upsampler model using torch.compile for optimization,
|
|
374
|
+
falling back to uncompiled execution if compilation fails.
|
|
375
|
+
"""
|
|
376
|
+
if self.use_compilation:
|
|
377
|
+
try:
|
|
378
|
+
self.upsampler_model = self.upsampler_model.to(self.device)
|
|
379
|
+
self.upsampler_model = torch.compile(self.upsampler_model, mode="reduce-overhead")
|
|
380
|
+
|
|
381
|
+
if self.master_process:
|
|
382
|
+
print("Models compiled successfully")
|
|
383
|
+
except Exception as e:
|
|
384
|
+
if self.master_process:
|
|
385
|
+
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
386
|
+
|
|
387
|
+
def corrupt_conditioning_image(self, x_low: torch.Tensor, corruption_type: str = "gaussian_blur") -> torch.Tensor:
    """Randomly degrade the low-resolution conditioning image.

    ``"gaussian_blur"`` blurs with a random kernel size in {3, 5, 7} and a
    sigma drawn uniformly from [0.5, 2.0] (Python ``random`` module).
    ``"bsr_degradation"`` delegates to ``_bsr_degradation``. Any other value
    returns the input unchanged.

    Parameters
    ----------
    `x_low` : torch.Tensor
        Low-resolution image batch.
    `corruption_type` : str, optional
        ``"gaussian_blur"`` (default), ``"bsr_degradation"`` or anything
        else for no corruption.

    Returns
    -------
    x_degraded : torch.Tensor
        The corrupted (or untouched) image batch.
    """
    if corruption_type == "bsr_degradation":
        return self._bsr_degradation(x_low)
    if corruption_type != "gaussian_blur":
        return x_low

    blur_size = random.choice([3, 5, 7])
    blur_sigma = random.uniform(0.5, 2.0)
    return self._gaussian_blur(x_low, blur_size, blur_sigma)
|
|
415
|
+
|
|
416
|
+
def _gaussian_blur(self, x: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
|
|
417
|
+
"""Applies Gaussian blur to the input image.
|
|
418
|
+
|
|
419
|
+
Parameters
|
|
420
|
+
----------
|
|
421
|
+
`x` : torch.Tensor
|
|
422
|
+
Input image tensor, shape (batch_size, channels, height, width).
|
|
423
|
+
`kernel_size` : int
|
|
424
|
+
Size of the Gaussian kernel.
|
|
425
|
+
`sigma` : float
|
|
426
|
+
Standard deviation of the Gaussian distribution.
|
|
427
|
+
|
|
428
|
+
Returns
|
|
429
|
+
-------
|
|
430
|
+
x_blurred : torch.Tensor
|
|
431
|
+
Blurred image tensor, same shape as input.
|
|
432
|
+
"""
|
|
433
|
+
# create Gaussian kernel
|
|
434
|
+
kernel = self._get_gaussian_kernel(kernel_size, sigma).to(x.device)
|
|
435
|
+
kernel = kernel.expand(x.shape[1], 1, kernel_size, kernel_size)
|
|
436
|
+
padding = kernel_size // 2
|
|
437
|
+
return F.conv2d(x, kernel, padding=padding, groups=x.shape[1])
|
|
438
|
+
|
|
439
|
+
def _get_gaussian_kernel(self, kernel_size: int, sigma: float) -> torch.Tensor:
|
|
440
|
+
"""Generates a 2D Gaussian kernel.
|
|
441
|
+
|
|
442
|
+
Parameters
|
|
443
|
+
----------
|
|
444
|
+
`kernel_size` : int
|
|
445
|
+
Size of the Gaussian kernel.
|
|
446
|
+
`sigma` : float
|
|
447
|
+
Standard deviation of the Gaussian distribution.
|
|
448
|
+
|
|
449
|
+
Returns
|
|
450
|
+
-------
|
|
451
|
+
kernel : torch.Tensor
|
|
452
|
+
2D Gaussian kernel, shape (kernel_size, kernel_size).
|
|
453
|
+
"""
|
|
454
|
+
coords = torch.arange(kernel_size, dtype=torch.float32) - kernel_size // 2
|
|
455
|
+
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
|
|
456
|
+
g = g / g.sum()
|
|
457
|
+
return g[:, None] * g[None, :]
|
|
458
|
+
|
|
459
|
+
def _bsr_degradation(self, x: torch.Tensor) -> torch.Tensor:
|
|
460
|
+
"""Applies BSR degradation to the input image.
|
|
461
|
+
|
|
462
|
+
Simulates degradation with noise and Gaussian blur, as used in the UnCLIP paper
|
|
463
|
+
for the second upsampler.
|
|
464
|
+
|
|
465
|
+
Parameters
|
|
466
|
+
----------
|
|
467
|
+
`x` : torch.Tensor
|
|
468
|
+
Input image tensor, shape (batch_size, channels, height, width).
|
|
469
|
+
|
|
470
|
+
Returns
|
|
471
|
+
-------
|
|
472
|
+
x_degraded : torch.Tensor
|
|
473
|
+
Degraded image tensor, same shape as input, clamped to [-1, 1].
|
|
474
|
+
"""
|
|
475
|
+
# add noise
|
|
476
|
+
noise_level = random.uniform(0.0, 0.1)
|
|
477
|
+
noise = torch.randn_like(x) * noise_level
|
|
478
|
+
|
|
479
|
+
# apply blur
|
|
480
|
+
kernel_size = random.choice([3, 5, 7])
|
|
481
|
+
sigma = random.uniform(0.5, 3.0)
|
|
482
|
+
x_degraded = self._gaussian_blur(x + noise, kernel_size, sigma)
|
|
483
|
+
|
|
484
|
+
return torch.clamp(x_degraded, -1.0, 1.0)
|
|
485
|
+
|
|
486
|
+
def validate(self) -> float:
    """Compute the mean noise-prediction loss over the validation loader.

    For each batch, random timesteps and noise are drawn, the high-resolution
    images are noised with forward diffusion, and the low-resolution images
    are corrupted. The model's noise prediction is scored with
    ``self.objective``. Runs under ``torch.no_grad()`` with the model in eval
    mode. Training mode is restored afterwards, even if an error occurs; the
    variance scheduler is put back into eval mode when its betas are not
    trainable.

    Returns
    -------
    val_loss : float
        Mean loss per batch (over all ranks when DDP is enabled). NaN if no
        batches were seen.
    """
    # DDP does not forward attribute access to the wrapped module.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model
    corruption_type = "gaussian_blur" if base_model.low_res_size == 64 else "bsr_degradation"

    # set models to eval mode for evaluation
    self.upsampler_model.eval()
    base_model.forward_diffusion.eval()

    loss_sum = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for low_res_images, high_res_images in self.val_loader:
                low_res_images = low_res_images.to(self.device, non_blocking=True)
                high_res_images = high_res_images.to(self.device, non_blocking=True)
                batch_size = high_res_images.shape[0]
                timesteps = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device)
                noise = torch.randn_like(high_res_images)
                high_res_images_noisy = base_model.forward_diffusion(high_res_images, noise, timesteps)
                low_res_images_corrupted = self.corrupt_conditioning_image(low_res_images, corruption_type)
                predicted_noise = self.upsampler_model(high_res_images_noisy, timesteps, low_res_images_corrupted)
                loss = self.objective(predicted_noise, noise)
                loss_sum += loss.item()
                num_batches += 1
    finally:
        # return to training mode
        self.upsampler_model.train()
        variance_scheduler = base_model.forward_diffusion.variance_scheduler
        if not variance_scheduler.trainable_beta:
            variance_scheduler.eval()

    if self.use_ddp:
        # Reduce sums and counts so the result is a batch-weighted global mean.
        totals = torch.tensor([loss_sum, float(num_batches)], dtype=torch.float64, device=self.device)
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        loss_sum, num_batches = totals.tolist()

    return loss_sum / num_batches if num_batches else float("nan")
|
|
533
|
+
|
|
534
|
+
def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False, suffix: str = ""):
|
|
535
|
+
"""Saves model checkpoint.
|
|
536
|
+
|
|
537
|
+
Saves the state of the upsampler model, its variance scheduler, optimizer, and
|
|
538
|
+
schedulers, with options for best model and epoch-specific checkpoints.
|
|
539
|
+
|
|
540
|
+
Parameters
|
|
541
|
+
----------
|
|
542
|
+
`epoch` : int
|
|
543
|
+
Current epoch number.
|
|
544
|
+
`loss` : float
|
|
545
|
+
Current loss value.
|
|
546
|
+
`is_best` : bool, optional
|
|
547
|
+
Whether to save as the best model checkpoint (default: False).
|
|
548
|
+
`suffix` : str, optional
|
|
549
|
+
Suffix to add to checkpoint filename, default "".
|
|
550
|
+
"""
|
|
551
|
+
if not self.master_process:
|
|
552
|
+
return
|
|
553
|
+
checkpoint = {
|
|
554
|
+
'epoch': epoch,
|
|
555
|
+
'loss': loss,
|
|
556
|
+
# Core model
|
|
557
|
+
'upsampler_model_state_dict': self.upsampler_model.module.state_dict() if self.use_ddp else self.upsampler_model.state_dict(),
|
|
558
|
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
|
559
|
+
# Training configuration
|
|
560
|
+
'model_channels': self.upsampler_model.model_channels,
|
|
561
|
+
'num_res_blocks': self.upsampler_model.num_res_blocks,
|
|
562
|
+
'normalize': self.normalize_image_outputs,
|
|
563
|
+
'output_range': self.image_output_range
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
# Save variance scheduler (submodule of forward_diffusion)
|
|
567
|
+
checkpoint['variance_scheduler_state_dict'] = (
|
|
568
|
+
self.upsampler_model.module.forward_diffusion.variance_scheduler.state_dict() if self.use_ddp
|
|
569
|
+
else self.upsampler_model.forward_diffusion.variance_scheduler.state_dict()
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
# Save schedulers state
|
|
573
|
+
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
|
|
574
|
+
checkpoint['warmup_scheduler_state_dict'] = self.warmup_lr_scheduler.state_dict()
|
|
575
|
+
|
|
576
|
+
filename = f"unclip_upsampler_epoch_{epoch}{suffix}.pth"
|
|
577
|
+
if is_best:
|
|
578
|
+
filename = f"unclip_upsampler_best{suffix}.pth"
|
|
579
|
+
|
|
580
|
+
filepath = os.path.join(self.store_path, filename)
|
|
581
|
+
os.makedirs(self.store_path, exist_ok=True)
|
|
582
|
+
torch.save(checkpoint, filepath)
|
|
583
|
+
|
|
584
|
+
if is_best:
|
|
585
|
+
print(f"Best model saved: {filepath}")
|
|
586
|
+
|
|
587
|
+
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads model checkpoint.

    Restores the state of the upsampler model, its variance scheduler, optimizer, and
    schedulers from a saved checkpoint, handling DDP compatibility.

    Model weights are loaded into ``self.upsampler_model`` (the DDP wrapper when
    ``self.use_ddp`` is set), adding or stripping a leading ``module.`` key prefix as
    needed. The variance scheduler is loaded into the *unwrapped* model's
    ``forward_diffusion.variance_scheduler`` (the same object ``save_checkpoint``
    reads it from under DDP), so its keys never carry the DDP prefix. A failure to
    load any single component is reported with ``warnings.warn`` instead of being
    raised, so the remaining components are still restored.

    Parameters
    ----------
    `checkpoint_path` : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (0 if not stored).
    loss : float
        The loss at the checkpoint (``inf`` if not stored).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") from e

    # A DDP wrapper does not expose the wrapped model's own attributes
    # (forward_diffusion, model_channels, ...), so reach them through `.module`.
    base_model = self.upsampler_model.module if self.use_ddp else self.upsampler_model

    ddp_prefix = 'module.'

    def _load_model_state_dict(model: nn.Module, state_dict: dict, model_name: str,
                               expects_prefix: bool) -> None:
        """Load `state_dict` into `model`, reconciling the DDP 'module.' key prefix.

        `expects_prefix` states whether `model`'s own parameter keys start with
        'module.' (i.e. whether `model` is the DDP wrapper itself). Any failure is
        downgraded to a warning so the other components can still be loaded.
        """
        try:
            has_prefix = any(key.startswith(ddp_prefix) for key in state_dict)
            if expects_prefix and not has_prefix:
                state_dict = {f'{ddp_prefix}{k}': v for k, v in state_dict.items()}
            elif not expects_prefix and has_prefix:
                # Strip only the *leading* prefix: a plain str.replace would also
                # mangle inner names such as 'sub_module.weight'.
                state_dict = {
                    (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
                    for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            if self.master_process:
                print(f"✓ Loaded {model_name}")
        except Exception as e:
            warnings.warn(f"Failed to load {model_name}: {e}")

    # Load core upsampler model (keys must match the possibly DDP-wrapped module)
    if 'upsampler_model_state_dict' in checkpoint:
        _load_model_state_dict(self.upsampler_model, checkpoint['upsampler_model_state_dict'],
                               'upsampler_model', expects_prefix=self.use_ddp)

    # Load variance scheduler (submodule of forward_diffusion); falls back to the
    # 'hyper_params_state_dict' key when 'variance_scheduler_state_dict' is absent.
    if 'variance_scheduler_state_dict' in checkpoint or 'hyper_params_state_dict' in checkpoint:
        state_dict = checkpoint.get('variance_scheduler_state_dict',
                                    checkpoint.get('hyper_params_state_dict'))
        try:
            variance_scheduler = base_model.forward_diffusion.variance_scheduler
        except AttributeError as e:
            warnings.warn(f"Failed to load variance scheduler: {e}")
        else:
            # The scheduler itself is never DDP-wrapped, so its keys carry no prefix.
            _load_model_state_dict(variance_scheduler, state_dict, 'variance_scheduler',
                                   expects_prefix=False)

    # Load optimizer
    if 'optimizer_state_dict' in checkpoint:
        try:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if self.master_process:
                print("✓ Loaded optimizer")
        except Exception as e:
            warnings.warn(f"Failed to load optimizer state: {e}")

    # Load schedulers
    if 'scheduler_state_dict' in checkpoint:
        try:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded main scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load scheduler state: {e}")

    if 'warmup_scheduler_state_dict' in checkpoint:
        try:
            self.warmup_lr_scheduler.load_state_dict(checkpoint['warmup_scheduler_state_dict'])
            if self.master_process:
                print("✓ Loaded warmup scheduler")
        except Exception as e:
            warnings.warn(f"Failed to load warmup scheduler state: {e}")

    # Verify configuration compatibility (read from the unwrapped model so this
    # also works under DDP)
    if 'model_channels' in checkpoint:
        if checkpoint['model_channels'] != base_model.model_channels:
            warnings.warn(
                f"Model channels mismatch: checkpoint={checkpoint['model_channels']}, current={base_model.model_channels}")

    if 'num_res_blocks' in checkpoint:
        if checkpoint['num_res_blocks'] != base_model.num_res_blocks:
            warnings.warn(
                f"Num res blocks mismatch: checkpoint={checkpoint['num_res_blocks']}, current={base_model.num_res_blocks}")

    if 'normalize' in checkpoint:
        if checkpoint['normalize'] != self.normalize_image_outputs:
            warnings.warn(
                f"Normalize setting mismatch: checkpoint={checkpoint['normalize']}, current={self.normalize_image_outputs}")

    epoch = checkpoint.get('epoch', 0)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Successfully loaded checkpoint from {checkpoint_path}")
        print(f"Epoch: {epoch}, Loss: {loss:.4f}")

    return epoch, loss
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
"""
|
|
691
|
+
from prior_diff import VarianceSchedulerUnCLIP, ForwardUnCLIP
|
|
692
|
+
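from upsampler import UpsamplerUnCLIP
# NOTE: TrainUpsamplerUnCLIP (used below) must also be imported; adjust these module paths to where the classes live in your installation (e.g. torchdiff.unclip)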
|
|
693
|
+
import torch
|
|
694
|
+
import torch.optim as optim
|
|
695
|
+
import torch.nn as nn
|
|
696
|
+
from torch.utils.data import Dataset, DataLoader
|
|
697
|
+
|
|
698
|
+
# Define a dummy dataset for example purposes (replace with real dataset in practice)
|
|
699
|
+
class DummyDataset(Dataset):
|
|
700
|
+
def __init__(self, num_samples=1000, low_res_size=64, high_res_size=256):
|
|
701
|
+
self.num_samples = num_samples
|
|
702
|
+
self.low_res_size = low_res_size
|
|
703
|
+
self.high_res_size = high_res_size
|
|
704
|
+
|
|
705
|
+
def __len__(self):
|
|
706
|
+
return self.num_samples
|
|
707
|
+
|
|
708
|
+
def __getitem__(self, idx):
|
|
709
|
+
# Generate random low-res and high-res images (in practice, load from disk or augment)
|
|
710
|
+
low_res_image = torch.rand(3, self.low_res_size, self.low_res_size) * 2 - 1 # Normalize to [-1, 1]
|
|
711
|
+
high_res_image = torch.rand(3, self.high_res_size, self.high_res_size) * 2 - 1 # Normalize to [-1, 1]
|
|
712
|
+
return low_res_image, high_res_image
|
|
713
|
+
|
|
714
|
+
# Instantiate the variance scheduler
|
|
715
|
+
hyp = VarianceSchedulerUnCLIP(
|
|
716
|
+
num_steps=400,
|
|
717
|
+
beta_start=1e-4,
|
|
718
|
+
beta_end=0.02,
|
|
719
|
+
trainable_beta=True,
|
|
720
|
+
beta_method="linear"
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
# Instantiate the forward diffusion process
|
|
724
|
+
forward = ForwardUnCLIP(hyp)
|
|
725
|
+
|
|
726
|
+
# Instantiate the upsampler model
|
|
727
|
+
model = UpsamplerUnCLIP(
|
|
728
|
+
forward_diffusion=forward,
|
|
729
|
+
in_channels=3,
|
|
730
|
+
out_channels=3,
|
|
731
|
+
model_channels=32,
|
|
732
|
+
num_res_blocks=2,
|
|
733
|
+
channel_mult=(1, 2, 4, 8),
|
|
734
|
+
dropout=0.1,
|
|
735
|
+
time_embed_dim=32,
|
|
736
|
+
low_res_size=64,
|
|
737
|
+
high_res_size=256
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
# Create train loader with dummy dataset (replace with real DataLoader for your dataset)
|
|
741
|
+
train_dataset = DummyDataset(num_samples=4)
|
|
742
|
+
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
|
|
743
|
+
|
|
744
|
+
# Optional validation loader (a small separate dummy dataset, for illustration)
|
|
745
|
+
val_dataset = DummyDataset(num_samples=2)
|
|
746
|
+
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)
|
|
747
|
+
|
|
748
|
+
# Define optimizer
|
|
749
|
+
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
|
750
|
+
|
|
751
|
+
# Define objective (loss function, e.g., MSE for noise prediction)
|
|
752
|
+
objective = nn.MSELoss()
|
|
753
|
+
|
|
754
|
+
# Instantiate the trainer
|
|
755
|
+
trainer = TrainUpsamplerUnCLIP(
|
|
756
|
+
upsampler_model=model,
|
|
757
|
+
train_loader=train_loader,
|
|
758
|
+
optimizer=optimizer,
|
|
759
|
+
objective=objective,
|
|
760
|
+
val_loader=val_loader, # Optional
|
|
761
|
+
max_epoch=10, # Small number for example; increase for real training
|
|
762
|
+
device='cuda' if torch.cuda.is_available() else 'cpu',
|
|
763
|
+
store_path="upsampler",
|
|
764
|
+
patience=10,
|
|
765
|
+
warmup_epochs=2,
|
|
766
|
+
val_frequency=5,
|
|
767
|
+
use_ddp=False, # Set to True if using distributed training
|
|
768
|
+
num_grad_accumulation=2,
|
|
769
|
+
progress_frequency=1,
|
|
770
|
+
compilation=True, # Set to False if torch.compile is unavailable or unsupported in your environment
|
|
771
|
+
output_range=(-1.0, 1.0),
|
|
772
|
+
normalize=True,
|
|
773
|
+
use_autocast=False
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
# Run the training
|
|
777
|
+
train_losses, best_val_loss = trainer()
|
|
778
|
+
|
|
779
|
+
# Print results
|
|
780
|
+
print(f"Training losses: {train_losses}")
|
|
781
|
+
print(f"Best validation loss: {best_val_loss}")
|
|
782
|
+
|
|
783
|
+
"""
|
|
784
|
+
|