TorchDiff 2.2.0__tar.gz → 2.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchdiff-2.2.0 → torchdiff-2.3.0}/PKG-INFO +1 -1
- {torchdiff-2.2.0 → torchdiff-2.3.0}/README.md +3 -3
- {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/PKG-INFO +1 -1
- {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/SOURCES.txt +10 -14
- {torchdiff-2.2.0 → torchdiff-2.3.0}/setup.py +1 -1
- torchdiff-2.3.0/torchdiff/__init__.py +13 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ddim.py +78 -82
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ddpm.py +45 -49
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/ldm.py +115 -102
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/sde.py +59 -61
- torchdiff-2.3.0/torchdiff/unclip.py +3776 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/utils.py +1 -2
- torchdiff-2.2.0/unclip/clip_model.py → torchdiff-2.3.0/unclip/clip_encoder.py +8 -138
- torchdiff-2.3.0/unclip/forward_unclip.py +55 -0
- torchdiff-2.2.0/unclip/project_prior.py → torchdiff-2.3.0/unclip/projections.py +76 -49
- torchdiff-2.3.0/unclip/reverse_unclip.py +114 -0
- torchdiff-2.3.0/unclip/scheduler.py +125 -0
- torchdiff-2.3.0/unclip/train_unclip_decoder.py +779 -0
- torchdiff-2.3.0/unclip/train_unclip_prior.py +497 -0
- torchdiff-2.2.0/unclip/decoder_model.py → torchdiff-2.3.0/unclip/unclip_decoder.py +94 -129
- torchdiff-2.3.0/unclip/unclip_sampler.py +307 -0
- torchdiff-2.3.0/unclip/unclip_trainstormer_prior.py +357 -0
- torchdiff-2.3.0/unclip/upsampler_trainer.py +559 -0
- torchdiff-2.2.0/unclip/upsampler.py → torchdiff-2.3.0/unclip/upsampler_unclip.py +65 -141
- torchdiff-2.2.0/torchdiff/__init__.py +0 -8
- torchdiff-2.2.0/torchdiff/tests/test_ldm.py +0 -660
- torchdiff-2.2.0/torchdiff/tests/test_unclip.py +0 -315
- torchdiff-2.2.0/torchdiff/unclip.py +0 -4171
- torchdiff-2.2.0/unclip/ddim_model.py +0 -1296
- torchdiff-2.2.0/unclip/prior_diff.py +0 -402
- torchdiff-2.2.0/unclip/prior_model.py +0 -264
- torchdiff-2.2.0/unclip/project_decoder.py +0 -57
- torchdiff-2.2.0/unclip/train_decoder.py +0 -1059
- torchdiff-2.2.0/unclip/train_prior.py +0 -757
- torchdiff-2.2.0/unclip/unclip_sampler.py +0 -626
- torchdiff-2.2.0/unclip/upsampler_trainer.py +0 -784
- torchdiff-2.2.0/unclip/utils.py +0 -1793
- torchdiff-2.2.0/unclip/val_metrics.py +0 -221
- {torchdiff-2.2.0 → torchdiff-2.3.0}/LICENSE +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/dependency_links.txt +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/requires.txt +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/TorchDiff.egg-info/top_level.txt +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/forward_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/reverse_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/sample_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/scheduler.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/test_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddim/train_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/forward_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/reverse_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/sample_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/scheduler.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/test_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ddpm/train_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/autoencoder.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/sample_ldm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/train_autoencoder.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/ldm/train_ldm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/forward_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/reverse_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/sample_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/scheduler.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/test_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/sde/train_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/setup.cfg +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_ddim.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_ddpm.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/torchdiff/tests/test_sde.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/unclip/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/__init__.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/diff_net.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/losses.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/metrics.py +0 -0
- {torchdiff-2.2.0 → torchdiff-2.3.0}/utils/text_encoder.py +0 -0
|
@@ -9,8 +9,8 @@
|
|
|
9
9
|
|
|
10
10
|
[](https://opensource.org/licenses/MIT)
|
|
11
11
|
[](https://pytorch.org/)
|
|
12
|
-
[](https://pypi.org/project/torchdiff/)
|
|
13
|
+
[](https://www.python.org/)
|
|
14
14
|
[](https://pepy.tech/project/torchdiff)
|
|
15
15
|
[](https://github.com/LoqmanSamani/TorchDiff)
|
|
16
16
|
[](https://github.com/LoqmanSamani/TorchDiff)
|
|
@@ -118,7 +118,7 @@ trainer = TrainDDPM(
|
|
|
118
118
|
device = device,
|
|
119
119
|
grad_acc = 2
|
|
120
120
|
)
|
|
121
|
-
|
|
121
|
+
trainer()
|
|
122
122
|
|
|
123
123
|
# Sampling
|
|
124
124
|
sampler = SampleDDPM(
|
|
@@ -42,24 +42,20 @@ torchdiff/utils.py
|
|
|
42
42
|
torchdiff/tests/__init__.py
|
|
43
43
|
torchdiff/tests/test_ddim.py
|
|
44
44
|
torchdiff/tests/test_ddpm.py
|
|
45
|
-
torchdiff/tests/test_ldm.py
|
|
46
45
|
torchdiff/tests/test_sde.py
|
|
47
|
-
torchdiff/tests/test_unclip.py
|
|
48
46
|
unclip/__init__.py
|
|
49
|
-
unclip/
|
|
50
|
-
unclip/
|
|
51
|
-
unclip/
|
|
52
|
-
unclip/
|
|
53
|
-
unclip/
|
|
54
|
-
unclip/
|
|
55
|
-
unclip/
|
|
56
|
-
unclip/
|
|
57
|
-
unclip/train_prior.py
|
|
47
|
+
unclip/clip_encoder.py
|
|
48
|
+
unclip/forward_unclip.py
|
|
49
|
+
unclip/projections.py
|
|
50
|
+
unclip/reverse_unclip.py
|
|
51
|
+
unclip/scheduler.py
|
|
52
|
+
unclip/train_unclip_decoder.py
|
|
53
|
+
unclip/train_unclip_prior.py
|
|
54
|
+
unclip/unclip_decoder.py
|
|
58
55
|
unclip/unclip_sampler.py
|
|
59
|
-
unclip/
|
|
56
|
+
unclip/unclip_trainstormer_prior.py
|
|
60
57
|
unclip/upsampler_trainer.py
|
|
61
|
-
unclip/
|
|
62
|
-
unclip/val_metrics.py
|
|
58
|
+
unclip/upsampler_unclip.py
|
|
63
59
|
utils/__init__.py
|
|
64
60
|
utils/diff_net.py
|
|
65
61
|
utils/losses.py
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
__version__ = "2.3.0"
|
|
2
|
+
|
|
3
|
+
from .ddim import ForwardDDIM, ReverseDDIM, SchedulerDDIM, TrainDDIM, SampleDDIM
|
|
4
|
+
from .ddpm import ForwardDDPM, ReverseDDPM, SchedulerDDPM, TrainDDPM, SampleDDPM
|
|
5
|
+
from .ldm import TrainLDM, TrainAE, AutoencoderLDM, SampleLDM
|
|
6
|
+
from .sde import ForwardSDE, ReverseSDE, SchedulerSDE, TrainSDE, SampleSDE
|
|
7
|
+
from .unclip import (
|
|
8
|
+
ForwardUnCLIP, ReverseUnCLIP, SchedulerUnCLIP, CLIPEncoder,
|
|
9
|
+
SampleUnCLIP, UnClipDecoder, UnCLIPTransformerPrior,
|
|
10
|
+
CLIPContextProjection, CLIPEmbeddingProjection, TrainUnClipDecoder,
|
|
11
|
+
SampleUnCLIP, UpsamplerUnCLIP, TrainUpsamplerUnCLIP
|
|
12
|
+
)
|
|
13
|
+
from .utils import DiffusionNetwork, TextEncoder, Metrics, mse_loss, snr_capped_loss, ve_sigma_weighted_score_loss
|
|
@@ -345,7 +345,6 @@ class SchedulerDDIM(nn.Module):
|
|
|
345
345
|
"""
|
|
346
346
|
step_ratio = self.train_steps // self.sample_steps
|
|
347
347
|
inference_timesteps = torch.arange(0, self.train_steps, step_ratio)
|
|
348
|
-
|
|
349
348
|
self.register_buffer('inference_timesteps', inference_timesteps)
|
|
350
349
|
|
|
351
350
|
def set_inference_timesteps(self, num_inference_timesteps: int):
|
|
@@ -393,49 +392,49 @@ class TrainDDIM(nn.Module):
|
|
|
393
392
|
|
|
394
393
|
Parameters
|
|
395
394
|
----------
|
|
396
|
-
`
|
|
397
|
-
|
|
395
|
+
`diff_net` : nn.Module
|
|
396
|
+
Main model to predict noise/v/x0
|
|
398
397
|
fwd_ddim : nn.Module
|
|
399
398
|
Forward DDIM diffusion module for adding noise.
|
|
400
399
|
rwd_ddim: nn.Module
|
|
401
400
|
Reverse DDIM diffusion module for denoising.
|
|
402
401
|
`data_loader` : torch.utils.data.DataLoader
|
|
403
402
|
DataLoader for training data.
|
|
404
|
-
`
|
|
403
|
+
`optim` : torch.optim.Optimizer
|
|
405
404
|
Optimizer for training the noise predictor and conditional model (if applicable).
|
|
406
|
-
`
|
|
405
|
+
`loss_fn` : callable
|
|
407
406
|
Loss function to compute the difference between predicted and actual noise.
|
|
408
407
|
`val_loader` : torch.utils.data.DataLoader, optional
|
|
409
408
|
DataLoader for validation data, default None.
|
|
410
409
|
`max_epochs` : int, optional
|
|
411
|
-
Maximum number of training epochs (default:
|
|
412
|
-
`device` :
|
|
413
|
-
Device for computation (default: CUDA
|
|
414
|
-
`
|
|
410
|
+
Maximum number of training epochs (default: 100).
|
|
411
|
+
`device` : str
|
|
412
|
+
Device for computation (default: CUDA).
|
|
413
|
+
`cond_net` : nn.Module, optional
|
|
415
414
|
Model for conditional generation (e.g., text embeddings), default None.
|
|
416
415
|
`metrics_` : object, optional
|
|
417
416
|
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
|
|
418
|
-
`
|
|
417
|
+
`tokenizer` : BertTokenizer, optional
|
|
419
418
|
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
|
|
420
419
|
`max_token_length` : int, optional
|
|
421
420
|
Maximum length for tokenized prompts (default: 77).
|
|
422
421
|
`store_path` : str, optional
|
|
423
|
-
Path to save model checkpoints (default: "
|
|
422
|
+
Path to save model checkpoints (default: "ddim_train").
|
|
424
423
|
`patience` : int, optional
|
|
425
|
-
Number of epochs to wait for improvement before early stopping (default:
|
|
426
|
-
`
|
|
427
|
-
Number of epochs for learning rate warmup (default:
|
|
428
|
-
`
|
|
424
|
+
Number of epochs to wait for improvement before early stopping (default: 20).
|
|
425
|
+
`warmup_steps` : int, optional
|
|
426
|
+
Number of epochs for learning rate warmup (default: 1000).
|
|
427
|
+
`val_freq` : int, optional
|
|
429
428
|
Frequency (in epochs) for validation (default: 10).
|
|
430
|
-
`
|
|
429
|
+
`norm_range` : tuple, optional
|
|
431
430
|
Range for clamping generated images (default: (-1, 1)).
|
|
432
|
-
`
|
|
431
|
+
`norm_output` : bool, optional
|
|
433
432
|
Whether to normalize generated images to [0, 1] for metrics (default: True).
|
|
434
433
|
`use_ddp` : bool, optional
|
|
435
434
|
Whether to use Distributed Data Parallel training (default: False).
|
|
436
|
-
`
|
|
435
|
+
`grad_acc` : int, optional
|
|
437
436
|
Number of gradient accumulation steps before optimizer update (default: 1).
|
|
438
|
-
`
|
|
437
|
+
`log_freq` : int, optional
|
|
439
438
|
Number of epochs before printing loss.
|
|
440
439
|
use_comp : bool, optional
|
|
441
440
|
whether the model is internally compiled using torch.compile (default: false)
|
|
@@ -449,15 +448,15 @@ class TrainDDIM(nn.Module):
|
|
|
449
448
|
optim: torch.optim.Optimizer,
|
|
450
449
|
loss_fn: Callable,
|
|
451
450
|
val_loader: Optional[torch.utils.data.DataLoader] = None,
|
|
452
|
-
max_epochs: int =
|
|
451
|
+
max_epochs: int = 100,
|
|
453
452
|
device: str = 'cuda',
|
|
454
|
-
|
|
453
|
+
cond_net: torch.nn.Module = None,
|
|
455
454
|
metrics_: Optional[Any] = None,
|
|
456
|
-
|
|
455
|
+
tokenizer: Optional[BertTokenizer] = None,
|
|
457
456
|
max_token_length: int = 77,
|
|
458
457
|
store_path: Optional[str] = None,
|
|
459
|
-
patience: int =
|
|
460
|
-
warmup_steps: int =
|
|
458
|
+
patience: int = 20,
|
|
459
|
+
warmup_steps: int = 1000,
|
|
461
460
|
val_freq: int = 10,
|
|
462
461
|
norm_range: Tuple[float, float] = (-1, 1),
|
|
463
462
|
norm_output: bool = True,
|
|
@@ -481,11 +480,11 @@ class TrainDDIM(nn.Module):
|
|
|
481
480
|
self.diff_net = diff_net.to(self.device)
|
|
482
481
|
self.fwd_ddim = fwd_ddim.to(self.device)
|
|
483
482
|
self.rwd_ddim = rwd_ddim.to(self.device)
|
|
484
|
-
self.
|
|
483
|
+
self.cond_net = cond_net.to(self.device) if cond_net else None
|
|
485
484
|
self.metrics_ = metrics_
|
|
486
485
|
self.optim = optim
|
|
487
486
|
self.loss_fn = loss_fn
|
|
488
|
-
self.store_path = store_path or "
|
|
487
|
+
self.store_path = store_path or "ddim_train"
|
|
489
488
|
self.train_loader = train_loader
|
|
490
489
|
self.val_loader = val_loader
|
|
491
490
|
self.max_epochs = max_epochs
|
|
@@ -506,7 +505,7 @@ class TrainDDIM(nn.Module):
|
|
|
506
505
|
factor=0.5
|
|
507
506
|
)
|
|
508
507
|
self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps)
|
|
509
|
-
if
|
|
508
|
+
if tokenizer is None:
|
|
510
509
|
try:
|
|
511
510
|
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
512
511
|
except Exception as e:
|
|
@@ -574,14 +573,14 @@ class TrainDDIM(nn.Module):
|
|
|
574
573
|
elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()):
|
|
575
574
|
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
|
|
576
575
|
self.diff_net.load_state_dict(state_dict)
|
|
577
|
-
if self.
|
|
576
|
+
if self.cond_net is not None:
|
|
578
577
|
if 'model_state_dict_cond' in checkpoint and checkpoint['model_state_dict_cond'] is not None:
|
|
579
578
|
cond_state_dict = checkpoint['model_state_dict_cond']
|
|
580
579
|
if self.use_ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()):
|
|
581
580
|
cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()}
|
|
582
581
|
elif not self.use_ddp and any(key.startswith('module.') for key in cond_state_dict.keys()):
|
|
583
582
|
cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()}
|
|
584
|
-
self.
|
|
583
|
+
self.cond_net.load_state_dict(cond_state_dict)
|
|
585
584
|
else:
|
|
586
585
|
warnings.warn(
|
|
587
586
|
"Checkpoint contains no 'model_state_dict_cond' or it is None, "
|
|
@@ -625,8 +624,8 @@ class TrainDDIM(nn.Module):
|
|
|
625
624
|
----------
|
|
626
625
|
`optimizer` : torch.optim.Optimizer
|
|
627
626
|
Optimizer to apply the scheduler to.
|
|
628
|
-
`
|
|
629
|
-
Number of
|
|
627
|
+
`warmup_steps` : int
|
|
628
|
+
Number of steps for the warmup phase.
|
|
630
629
|
|
|
631
630
|
Returns
|
|
632
631
|
-------
|
|
@@ -646,9 +645,9 @@ class TrainDDIM(nn.Module):
|
|
|
646
645
|
device_ids=[self.ddp_local_rank],
|
|
647
646
|
find_unused_parameters=True
|
|
648
647
|
)
|
|
649
|
-
if self.
|
|
650
|
-
self.
|
|
651
|
-
self.
|
|
648
|
+
if self.cond_net is not None:
|
|
649
|
+
self.cond_net = DDP(
|
|
650
|
+
self.cond_net,
|
|
652
651
|
device_ids=[self.ddp_local_rank],
|
|
653
652
|
find_unused_parameters=True
|
|
654
653
|
)
|
|
@@ -662,20 +661,17 @@ class TrainDDIM(nn.Module):
|
|
|
662
661
|
|
|
663
662
|
Returns
|
|
664
663
|
-------
|
|
665
|
-
|
|
666
|
-
List of mean training losses per epoch.
|
|
667
|
-
best_val_loss : float
|
|
668
|
-
Best validation or training loss achieved.
|
|
664
|
+
losses: dictionlary contains train and validation losses
|
|
669
665
|
"""
|
|
670
666
|
self.diff_net.train()
|
|
671
|
-
if self.
|
|
672
|
-
self.
|
|
667
|
+
if self.cond_net is not None:
|
|
668
|
+
self.cond_net.train()
|
|
673
669
|
|
|
674
670
|
if self.use_comp:
|
|
675
671
|
try:
|
|
676
672
|
self.diff_net = torch.compile(self.diff_net)
|
|
677
|
-
if self.
|
|
678
|
-
self.
|
|
673
|
+
if self.cond_net is not None:
|
|
674
|
+
self.cond_net = torch.compile(self.cond_net)
|
|
679
675
|
except Exception as e:
|
|
680
676
|
if self.master_process:
|
|
681
677
|
print(f"Model compilation failed: {e}. Continuing without compilation.")
|
|
@@ -690,7 +686,7 @@ class TrainDDIM(nn.Module):
|
|
|
690
686
|
train_losses_epoch = []
|
|
691
687
|
for step, (x, y) in enumerate(pbar):
|
|
692
688
|
x = x.to(self.device)
|
|
693
|
-
if self.
|
|
689
|
+
if self.cond_net is not None:
|
|
694
690
|
y_encoded = self._process_conditional_input(y)
|
|
695
691
|
else:
|
|
696
692
|
y_encoded = None
|
|
@@ -705,8 +701,8 @@ class TrainDDIM(nn.Module):
|
|
|
705
701
|
if (step + 1) % self.grad_acc == 0:
|
|
706
702
|
scaler.unscale_(self.optim)
|
|
707
703
|
torch.nn.utils.clip_grad_norm_(self.diff_net.parameters(), max_norm=1.0)
|
|
708
|
-
if self.
|
|
709
|
-
torch.nn.utils.clip_grad_norm_(self.
|
|
704
|
+
if self.cond_net is not None:
|
|
705
|
+
torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0)
|
|
710
706
|
scaler.step(self.optim)
|
|
711
707
|
scaler.update()
|
|
712
708
|
self.optim.zero_grad()
|
|
@@ -786,7 +782,7 @@ class TrainDDIM(nn.Module):
|
|
|
786
782
|
).to(self.device)
|
|
787
783
|
input_ids = y_encoded["input_ids"]
|
|
788
784
|
attention_mask = y_encoded["attention_mask"]
|
|
789
|
-
y_encoded = self.
|
|
785
|
+
y_encoded = self.cond_net(input_ids, attention_mask)
|
|
790
786
|
return y_encoded
|
|
791
787
|
|
|
792
788
|
def _save_checkpoint(self, epoch: int, loss: float, pref: str = "") -> None:
|
|
@@ -807,10 +803,10 @@ class TrainDDIM(nn.Module):
|
|
|
807
803
|
else self.diff_net.state_dict()
|
|
808
804
|
)
|
|
809
805
|
cond_state = None
|
|
810
|
-
if self.
|
|
806
|
+
if self.cond_net is not None:
|
|
811
807
|
cond_state = (
|
|
812
|
-
self.
|
|
813
|
-
else self.
|
|
808
|
+
self.cond_net.module.state_dict() if self.use_ddp
|
|
809
|
+
else self.cond_net.state_dict()
|
|
814
810
|
)
|
|
815
811
|
checkpoint = {
|
|
816
812
|
'epoch': epoch,
|
|
@@ -826,8 +822,7 @@ class TrainDDIM(nn.Module):
|
|
|
826
822
|
filepath = os.path.join(self.store_path, filename)
|
|
827
823
|
os.makedirs(self.store_path, exist_ok=True)
|
|
828
824
|
torch.save(checkpoint, filepath)
|
|
829
|
-
|
|
830
|
-
print(f"Model saved at epoch {epoch}")
|
|
825
|
+
print(f"Model saved at epoch {epoch} with loss: {loss:.4f}")
|
|
831
826
|
except Exception as e:
|
|
832
827
|
print(f"Failed to save model: {e}")
|
|
833
828
|
|
|
@@ -856,8 +851,8 @@ class TrainDDIM(nn.Module):
|
|
|
856
851
|
"""
|
|
857
852
|
|
|
858
853
|
self.diff_net.eval()
|
|
859
|
-
if self.
|
|
860
|
-
self.
|
|
854
|
+
if self.cond_net is not None:
|
|
855
|
+
self.cond_net.eval()
|
|
861
856
|
|
|
862
857
|
val_losses = []
|
|
863
858
|
fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
|
|
@@ -865,7 +860,7 @@ class TrainDDIM(nn.Module):
|
|
|
865
860
|
for x, y in self.val_loader:
|
|
866
861
|
x = x.to(self.device)
|
|
867
862
|
x_orig = x.clone()
|
|
868
|
-
if self.
|
|
863
|
+
if self.cond_net is not None:
|
|
869
864
|
y_encoded = self._process_conditional_input(y)
|
|
870
865
|
else:
|
|
871
866
|
y_encoded = None
|
|
@@ -880,10 +875,10 @@ class TrainDDIM(nn.Module):
|
|
|
880
875
|
xt = torch.randn_like(x)
|
|
881
876
|
timesteps = self.fwd_ddim.vs.inference_timesteps.flip(0)
|
|
882
877
|
for i in range(len(timesteps) - 1):
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
time = torch.full((xt.shape[0],),
|
|
886
|
-
prev_time = torch.full((xt.shape[0],),
|
|
878
|
+
t_ = timesteps[i].item()
|
|
879
|
+
t_pre = timesteps[i + 1].item()
|
|
880
|
+
time = torch.full((xt.shape[0],), t_, device=self.device, dtype=torch.long)
|
|
881
|
+
prev_time = torch.full((xt.shape[0],), t_pre, device=self.device, dtype=torch.long)
|
|
887
882
|
pred = self.diff_net(xt, time, y_encoded, clip_embeddings=None)
|
|
888
883
|
xt, _ = self.rwd_ddim(xt, time, prev_time, pred)
|
|
889
884
|
x_hat = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
|
|
@@ -915,8 +910,8 @@ class TrainDDIM(nn.Module):
|
|
|
915
910
|
lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
|
|
916
911
|
|
|
917
912
|
self.diff_net.train()
|
|
918
|
-
if self.
|
|
919
|
-
self.
|
|
913
|
+
if self.cond_net is not None:
|
|
914
|
+
self.cond_net.train()
|
|
920
915
|
return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
|
|
921
916
|
|
|
922
917
|
###==================================================================================================================###
|
|
@@ -931,13 +926,13 @@ class SampleDDIM(nn.Module):
|
|
|
931
926
|
|
|
932
927
|
Parameters
|
|
933
928
|
----------
|
|
934
|
-
`
|
|
929
|
+
`rwd_ddim` : nn.Module
|
|
935
930
|
Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
|
|
936
|
-
`
|
|
937
|
-
Trained model to predict noise at each time step.
|
|
938
|
-
`
|
|
931
|
+
`diff_net` : nn.Module
|
|
932
|
+
Trained model to predict noise/v/x0 at each time step.
|
|
933
|
+
`img_size` : tuple
|
|
939
934
|
Tuple of (height, width) specifying the generated image dimensions.
|
|
940
|
-
`
|
|
935
|
+
`cond_net` : nn.Module, optional
|
|
941
936
|
Model for conditional generation (e.g., text embeddings), default None.
|
|
942
937
|
`tokenizer` : str, optional
|
|
943
938
|
Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
|
|
@@ -947,9 +942,9 @@ class SampleDDIM(nn.Module):
|
|
|
947
942
|
Number of images to generate per batch (default: 1).
|
|
948
943
|
`in_channels` : int, optional
|
|
949
944
|
Number of input channels for generated images (default: 3).
|
|
950
|
-
`device` :
|
|
951
|
-
Device for computation (default: CUDA
|
|
952
|
-
`
|
|
945
|
+
`device` : str
|
|
946
|
+
Device for computation (default: CUDA).
|
|
947
|
+
`norm_range` : tuple, optional
|
|
953
948
|
Tuple of (min, max) for clamping generated images (default: (-1, 1)).
|
|
954
949
|
"""
|
|
955
950
|
def __init__(
|
|
@@ -957,7 +952,7 @@ class SampleDDIM(nn.Module):
|
|
|
957
952
|
rwd_ddim: torch.nn.Module,
|
|
958
953
|
diff_net: torch.nn.Module,
|
|
959
954
|
img_size: Tuple[int, int],
|
|
960
|
-
|
|
955
|
+
cond_net: Optional[torch.nn.Module] = None,
|
|
961
956
|
tokenizer: str = "bert-base-uncased",
|
|
962
957
|
max_token_length: int = 77,
|
|
963
958
|
batch_size: int = 1,
|
|
@@ -972,7 +967,7 @@ class SampleDDIM(nn.Module):
|
|
|
972
967
|
self.device = device
|
|
973
968
|
self.rwd_ddim = rwd_ddim.to(self.device)
|
|
974
969
|
self.diff_net = diff_net.to(self.device)
|
|
975
|
-
self.
|
|
970
|
+
self.cond_net = cond_net.to(self.device) if cond_net else None
|
|
976
971
|
self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
|
|
977
972
|
self.max_token_length = max_token_length
|
|
978
973
|
self.in_channels = in_channels
|
|
@@ -1035,21 +1030,21 @@ class SampleDDIM(nn.Module):
|
|
|
1035
1030
|
`save_imgs` : bool, optional
|
|
1036
1031
|
If True, saves generated images to `save_path` (default: True).
|
|
1037
1032
|
`save_path` : str, optional
|
|
1038
|
-
Directory to save generated images (default: "
|
|
1033
|
+
Directory to save generated images (default: "ddim_samples").
|
|
1039
1034
|
|
|
1040
1035
|
Returns
|
|
1041
1036
|
-------
|
|
1042
1037
|
samps (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width).
|
|
1043
1038
|
"""
|
|
1044
|
-
if conds is not None and self.
|
|
1039
|
+
if conds is not None and self.cond_net is None:
|
|
1045
1040
|
raise ValueError("Conditions provided but no conditional model specified")
|
|
1046
|
-
if conds is None and self.
|
|
1041
|
+
if conds is None and self.cond_net is not None:
|
|
1047
1042
|
raise ValueError("Conditions must be provided for conditional model")
|
|
1048
1043
|
|
|
1049
1044
|
init_samps = torch.randn(self.batch_size, self.in_channels, self.img_size[0], self.img_size[1]).to(self.device)
|
|
1050
1045
|
self.diff_net.eval()
|
|
1051
|
-
if self.
|
|
1052
|
-
self.
|
|
1046
|
+
if self.cond_net:
|
|
1047
|
+
self.cond_net.eval()
|
|
1053
1048
|
timesteps = self.rwd_ddim.vs.inference_timesteps
|
|
1054
1049
|
timesteps = timesteps.flip(0)
|
|
1055
1050
|
iterator = tqdm(
|
|
@@ -1059,10 +1054,10 @@ class SampleDDIM(nn.Module):
|
|
|
1059
1054
|
dynamic_ncols=True,
|
|
1060
1055
|
leave=True,
|
|
1061
1056
|
)
|
|
1062
|
-
if self.
|
|
1057
|
+
if self.cond_net is not None and conds is not None:
|
|
1063
1058
|
input_ids, attention_masks = self.tokenize(conds)
|
|
1064
1059
|
key_padding_mask = (attention_masks == 0)
|
|
1065
|
-
y = self.
|
|
1060
|
+
y = self.cond_net(input_ids, key_padding_mask)
|
|
1066
1061
|
else:
|
|
1067
1062
|
y = None
|
|
1068
1063
|
|
|
@@ -1070,9 +1065,10 @@ class SampleDDIM(nn.Module):
|
|
|
1070
1065
|
xt = init_samps
|
|
1071
1066
|
for i in iterator:
|
|
1072
1067
|
t_current = timesteps[i].item()
|
|
1073
|
-
|
|
1068
|
+
t_prev = timesteps[i + 1].item()
|
|
1069
|
+
#assert t_current > t_prev or t_prev == 0
|
|
1074
1070
|
time = torch.full((self.batch_size,), t_current, device=self.device, dtype=torch.long)
|
|
1075
|
-
prev_time = torch.full((self.batch_size,),
|
|
1071
|
+
prev_time = torch.full((self.batch_size,), t_prev, device=self.device, dtype=torch.long)
|
|
1076
1072
|
pred = self.diff_net(xt, time, y, clip_embeddings=None)
|
|
1077
1073
|
xt, _ = self.rwd_ddim(xt, time, prev_time, pred)
|
|
1078
1074
|
samps = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
|
|
@@ -1081,7 +1077,7 @@ class SampleDDIM(nn.Module):
|
|
|
1081
1077
|
if save_imgs:
|
|
1082
1078
|
os.makedirs(save_path, exist_ok=True)
|
|
1083
1079
|
for i in range(samps.size(0)):
|
|
1084
|
-
img_path = os.path.join(save_path, f"img_{i
|
|
1080
|
+
img_path = os.path.join(save_path, f"img_{i+1}.png")
|
|
1085
1081
|
save_image(samps[i], img_path)
|
|
1086
1082
|
return samps
|
|
1087
1083
|
|
|
@@ -1102,6 +1098,6 @@ class SampleDDIM(nn.Module):
|
|
|
1102
1098
|
"""
|
|
1103
1099
|
self.device = device
|
|
1104
1100
|
self.diff_net.to(device)
|
|
1105
|
-
if self.
|
|
1106
|
-
self.
|
|
1101
|
+
if self.cond_net:
|
|
1102
|
+
self.cond_net.to(device)
|
|
1107
1103
|
return super().to(device)
|