monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/inferers/inferer.py
CHANGED
```diff
@@ -11,24 +11,41 @@
 
 from __future__ import annotations
 
+import math
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from functools import partial
 from pydoc import locate
 from typing import Any
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.apps.utils import get_logger
+from monai.data import decollate_batch
 from monai.data.meta_tensor import MetaTensor
 from monai.data.thread_buffer import ThreadBuffer
 from monai.inferers.merger import AvgMerger, Merger
 from monai.inferers.splitter import Splitter
 from monai.inferers.utils import compute_importance_map, sliding_window_inference
-from monai.
+from monai.networks.nets import (
+    VQVAE,
+    AutoencoderKL,
+    ControlNet,
+    DecoderOnlyTransformer,
+    DiffusionModelUNet,
+    SPADEAutoencoderKL,
+    SPADEDiffusionModelUNet,
+)
+from monai.networks.schedulers import Scheduler
+from monai.transforms import CenterSpatialCrop, SpatialPad
+from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
 from monai.visualize import CAM, GradCAM, GradCAMpp
 
+tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+
 logger = get_logger(__name__)
 
 __all__ = [
```
```diff
@@ -752,3 +769,1264 @@ class SliceInferer(SlidingWindowInferer):
             return out
 
         return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out)
+
+
+class DiffusionInferer(Inferer):
+    """
+    DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
+    for a training iteration, and sample from the model.
+
+    Args:
+        scheduler: diffusion scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler) -> None:  # type: ignore[override]
+        super().__init__()
+
+        self.scheduler = scheduler
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: Input image to which noise is added.
+            diffusion_model: diffusion model.
+            noise: random noise, of the same shape as the input.
+            timesteps: random timesteps.
+            condition: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
+                provided on the forward (for SPADE-like AE or SPADE-like DM)
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+        if mode == "concat":
+            if condition is None:
+                raise ValueError("Conditioning is required for concat condition")
+            else:
+                noisy_image = torch.cat([noisy_image, condition], dim=1)
+                condition = None
+        diffusion_model = (
+            partial(diffusion_model, seg=seg)
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            else diffusion_model
+        )
+        prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(
+        self,
+        input_noise: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired sample.
+            diffusion_model: model to sample from.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if mode == "concat" and conditioning is None:
+            raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+        if not scheduler:
+            scheduler = self.scheduler
+        image = input_noise
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        for t in progress_bar:
+            # 1. predict noise model_output
+            diffusion_model = (
+                partial(diffusion_model, seg=seg)
+                if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+                else diffusion_model
+            )
+            if mode == "concat" and conditioning is not None:
+                model_input = torch.cat([image, conditioning], dim=1)
+                model_output = diffusion_model(
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
+                )
+            else:
+                model_output = diffusion_model(
+                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                )
+
+            # 2. compute previous image: x_t -> x_t-1
+            image, _ = scheduler.step(model_output, t, image)
+            if save_intermediates and t % intermediate_steps == 0:
+                intermediates.append(image)
+        if save_intermediates:
+            return image, intermediates
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods for an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                f"Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if mode == "concat" and conditioning is None:
+            raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            diffusion_model = (
+                partial(diffusion_model, seg=seg)
+                if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+                else diffusion_model
+            )
+            if mode == "concat" and conditioning is not None:
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
+            else:
+                model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+
+            # get the model's predicted mean, and variance if it is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute predicted original sample from predicted noise also called
+            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+            # 3. Clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. Compute predicted previous sample µ_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+
+        return 0.5 * (
+            1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
+        )
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            input: the target images. It is assumed that this was uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean Tensor.
+            log_scales: the Gaussian log stddev Tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+        """
+        if inputs.shape != means.shape:
+            raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}")
+        bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
+            original_input_range[1] - original_input_range[0]
+        )
+        centered_x = inputs - means
+        inv_stdv = torch.exp(-log_scales)
+        plus_in = inv_stdv * (centered_x + bin_width / 2)
+        cdf_plus = self._approx_standard_normal_cdf(plus_in)
+        min_in = inv_stdv * (centered_x - bin_width / 2)
+        cdf_min = self._approx_standard_normal_cdf(min_in)
+        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+        log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+        cdf_delta = cdf_plus - cdf_min
+        log_probs = torch.where(
+            inputs < -0.999,
+            log_cdf_plus,
+            torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+        )
+        return log_probs
+
+
```
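For orientation, a minimal training-step and sampling sketch for the `DiffusionInferer` added above. It assumes the new inferer and scheduler classes are re-exported from `monai.inferers` and `monai.networks.schedulers` (as the `__init__.py` changes in this release suggest); the helper function names, tensor shapes, and scheduler settings are illustrative assumptions, not taken from this diff.

```python
import torch
import torch.nn.functional as F

from monai.inferers import DiffusionInferer
from monai.networks.schedulers import DDPMScheduler


def diffusion_train_step(model, images: torch.Tensor) -> torch.Tensor:
    """One supervised iteration: predict the noise added to `images` at random timesteps.

    `model` is assumed to be a DiffusionModelUNet (or compatible) built by the caller.
    """
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler)

    noise = torch.randn_like(images)
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=images.device).long()

    # the inferer noises the input internally and runs the model once
    prediction = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
    return F.mse_loss(prediction, noise)


def diffusion_sample(model, shape=(1, 1, 64, 64)) -> torch.Tensor:
    """Unconditional sampling from pure noise with the same scheduler."""
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler)
    return inferer.sample(input_noise=torch.randn(shape), diffusion_model=model, scheduler=scheduler, verbose=False)
```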
```diff
+class LatentDiffusionInferer(DiffusionInferer):
+    """
+    LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
+    be used to perform a signal forward pass for a training iteration, and sample from the model.
+
+    Args:
+        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+            second stage.
+        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+        autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+            difference between the autoencoder's latent shape and the DM shape.
+    """
+
+    def __init__(
+        self,
+        scheduler: Scheduler,
+        scale_factor: float = 1.0,
+        ldm_latent_shape: list | None = None,
+        autoencoder_latent_shape: list | None = None,
+    ) -> None:
+        super().__init__(scheduler=scheduler)
+        self.scale_factor = scale_factor
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.")
+        self.ldm_latent_shape = ldm_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
+        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image to which the latent representation will be extracted and noise is added.
+            autoencoder_model: first stage model.
+            diffusion_model: diffusion model.
+            noise: random noise, of the same shape as the latent representation.
+            timesteps: random timesteps.
+            condition: conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        with torch.no_grad():
+            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+        prediction: torch.Tensor = super().__call__(
+            inputs=latent,
+            diffusion_model=diffusion_model,
+            noise=noise,
+            timesteps=timesteps,
+            condition=condition,
+            mode=mode,
+            seg=seg,
+        )
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired latent representation.
+            autoencoder_model: first stage model.
+            diffusion_model: model to sample from.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+
+        if (
+            isinstance(autoencoder_model, SPADEAutoencoderKL)
+            and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+        ):
+            raise ValueError(
+                f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+                f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and"
+                f"{diffusion_model.label_nc}"
+            )
+
+        outputs = super().sample(
+            input_noise=input_noise,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            intermediate_steps=intermediate_steps,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates:
+            latent, latent_intermediates = outputs
+        else:
+            latent = outputs
+
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
+
+        decode = autoencoder_model.decode_stage_2_outputs
+        if isinstance(autoencoder_model, SPADEAutoencoderKL):
+            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+        image = decode(latent / self.scale_factor)
+
+        if save_intermediates:
+            intermediates = []
+            for latent_intermediate in latent_intermediates:
+                decode = autoencoder_model.decode_stage_2_outputs
+                if isinstance(autoencoder_model, SPADEAutoencoderKL):
+                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+                intermediates.append(decode(latent_intermediate / self.scale_factor))
+            return image, intermediates
+
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool = True,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "nearest",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
+
+
```
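A minimal latent-space training-step sketch for the `LatentDiffusionInferer` added above. The `autoencoder` and `unet` instances are assumed to be built by the caller; deriving `scale_factor` from the standard deviation of a latent batch is a common convention and an assumption here, not something prescribed by this diff.

```python
import torch
import torch.nn.functional as F

from monai.inferers import LatentDiffusionInferer
from monai.networks.schedulers import DDPMScheduler


def latent_diffusion_train_step(autoencoder, unet, images: torch.Tensor) -> torch.Tensor:
    """One iteration in latent space; `autoencoder` is a trained AutoencoderKL/VQVAE, `unet` a DiffusionModelUNet."""
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # scale latents to roughly unit variance before diffusing them (illustrative choice)
    with torch.no_grad():
        latents = autoencoder.encode_stage_2_inputs(images)
    inferer = LatentDiffusionInferer(scheduler, scale_factor=(1.0 / latents.flatten().std()).item())

    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=images.device).long()

    # the inferer encodes `images`, noises the latent, and runs the diffusion model once
    prediction = inferer(
        inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise, timesteps=timesteps
    )
    return F.mse_loss(prediction, noise)
```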
```diff
+class ControlNetDiffusionInferer(DiffusionInferer):
+    """
+    ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal
+    forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.
+
+    Args:
+        scheduler: diffusion scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler) -> None:
+        Inferer.__init__(self)
+        self.scheduler = scheduler
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        cn_cond: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: Input image to which noise is added.
+            diffusion_model: diffusion model.
+            controlnet: controlnet sub-network.
+            noise: random noise, of the same shape as the input.
+            timesteps: random timesteps.
+            cn_cond: conditioning image for the ControlNet.
+            condition: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
+                provided on the forward (for SPADE-like AE or SPADE-like DM)
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
+        )
+        if mode == "concat" and condition is not None:
+            noisy_image = torch.cat([noisy_image, condition], dim=1)
+            condition = None
+
+        diffuse = diffusion_model
+        if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+            diffuse = partial(diffusion_model, seg=seg)
+
+        prediction: torch.Tensor = diffuse(
+            x=noisy_image,
+            timesteps=timesteps,
+            context=condition,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
+        )
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired sample.
+            diffusion_model: model to sample from.
+            controlnet: controlnet sub-network.
+            cn_cond: conditioning image for the ControlNet.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        if not scheduler:
+            scheduler = self.scheduler
+        image = input_noise
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        for t in progress_bar:
+            # 1. ControlNet forward
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
+            )
+            # 2. predict noise model_output
+            diffuse = diffusion_model
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+                diffuse = partial(diffusion_model, seg=seg)
+
+            if mode == "concat" and conditioning is not None:
+                model_input = torch.cat([image, conditioning], dim=1)
+                model_output = diffuse(
+                    model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    context=None,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            else:
+                model_output = diffuse(
+                    image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    context=conditioning,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+
+            # 3. compute previous image: x_t -> x_t-1
+            image, _ = scheduler.step(model_output, t, image)
+            if save_intermediates and t % intermediate_steps == 0:
+                intermediates.append(image)
+        if save_intermediates:
+            return image, intermediates
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods for an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            controlnet: controlnet sub-network.
+            cn_cond: conditioning image for the ControlNet.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                f"Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
+            )
+
+            diffuse = diffusion_model
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+                diffuse = partial(diffusion_model, seg=seg)
+
+            if mode == "concat" and conditioning is not None:
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                model_output = diffuse(
+                    noisy_image,
+                    timesteps=timesteps,
+                    context=None,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            else:
+                model_output = diffuse(
+                    x=noisy_image,
+                    timesteps=timesteps,
+                    context=conditioning,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            # get the model's predicted mean, and variance if it is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute predicted original sample from predicted noise also called
+            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+            # 3. Clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. Compute predicted previous sample µ_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -super()._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
+
```
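A minimal training-step sketch for the `ControlNetDiffusionInferer` added above. The `unet` and `controlnet` instances are assumed to be built by the caller; using a segmentation mask as `cn_cond` is an illustrative choice, not something taken from this diff.

```python
import torch
import torch.nn.functional as F

from monai.inferers import ControlNetDiffusionInferer
from monai.networks.schedulers import DDPMScheduler


def controlnet_train_step(unet, controlnet, images: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
    """One iteration with ControlNet conditioning; `masks` plays the role of `cn_cond`."""
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = ControlNetDiffusionInferer(scheduler)

    noise = torch.randn_like(images)
    timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=images.device).long()

    # the inferer runs the ControlNet first and feeds its residuals into the diffusion model
    prediction = inferer(
        inputs=images, diffusion_model=unet, controlnet=controlnet, noise=noise, timesteps=timesteps, cn_cond=masks
    )
    return F.mse_loss(prediction, noise)
```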
```diff
+class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
+    """
+    ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
+    and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from
+    the model.
+
+    Args:
+        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+            second stage.
+        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+        autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+            difference between the autoencoder's latent shape and the DM shape.
+    """
+
+    def __init__(
+        self,
+        scheduler: Scheduler,
+        scale_factor: float = 1.0,
+        ldm_latent_shape: list | None = None,
+        autoencoder_latent_shape: list | None = None,
+    ) -> None:
+        super().__init__(scheduler=scheduler)
+        self.scale_factor = scale_factor
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.")
+        self.ldm_latent_shape = ldm_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
+        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        cn_cond: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image to which the latent representation will be extracted and noise is added.
+            autoencoder_model: first stage model.
+            diffusion_model: diffusion model.
+            controlnet: instance of ControlNet model
+            noise: random noise, of the same shape as the latent representation.
+            timesteps: random timesteps.
+            cn_cond: conditioning tensor for the ControlNet network
+            condition: conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        with torch.no_grad():
+            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+        if cn_cond.shape[2:] != latent.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, latent.shape[2:])
+
+        prediction = super().__call__(
+            inputs=latent,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            noise=noise,
+            timesteps=timesteps,
+            cn_cond=cn_cond,
+            condition=condition,
+            mode=mode,
+            seg=seg,
+        )
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired latent representation.
+            autoencoder_model: first stage model.
+            diffusion_model: model to sample from.
+            controlnet: instance of ControlNet model.
+            cn_cond: conditioning tensor for the ControlNet network.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+
+        if (
+            isinstance(autoencoder_model, SPADEAutoencoderKL)
+            and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+        ):
+            raise ValueError(
+                "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+                "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}"
+            )
+
+        if cn_cond.shape[2:] != input_noise.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, input_noise.shape[2:])
+
+        outputs = super().sample(
+            input_noise=input_noise,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            cn_cond=cn_cond,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            intermediate_steps=intermediate_steps,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates:
+            latent, latent_intermediates = outputs
+        else:
+            latent = outputs
+
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
+
+        decode = autoencoder_model.decode_stage_2_outputs
+        if isinstance(autoencoder_model, SPADEAutoencoderKL):
+            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+
+        image = decode(latent / self.scale_factor)
+
+        if save_intermediates:
+            intermediates = []
+            for latent_intermediate in latent_intermediates:
+                decode = autoencoder_model.decode_stage_2_outputs
+                if isinstance(autoencoder_model, SPADEAutoencoderKL):
+                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+                intermediates.append(decode(latent_intermediate / self.scale_factor))
+            return image, intermediates
+
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool = True,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "nearest",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            controlnet: instance of ControlNet model.
+            cn_cond: conditioning tensor for the ControlNet network.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if cn_cond.shape[2:] != latents.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, latents.shape[2:])
+
+        if self.ldm_latent_shape is not None:
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            cn_cond=cn_cond,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
+
+
```
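A minimal sampling sketch for the `ControlNetLatentDiffusionInferer` added above; the latent shape, `scale_factor`, and the pre-trained `autoencoder`, `unet` and `controlnet` instances are all illustrative assumptions. Note that the inferer resizes `cn_cond` to the latent resolution internally, as the code above shows.

```python
import torch

from monai.inferers import ControlNetLatentDiffusionInferer
from monai.networks.schedulers import DDPMScheduler


@torch.no_grad()
def controlnet_latent_sample(autoencoder, unet, controlnet, cn_cond: torch.Tensor, latent_shape=(1, 3, 16, 16)):
    """Sample an image from noise in latent space, steered by a ControlNet conditioning map."""
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = ControlNetLatentDiffusionInferer(scheduler, scale_factor=1.0)
    return inferer.sample(
        input_noise=torch.randn(latent_shape),
        autoencoder_model=autoencoder,
        diffusion_model=unet,
        controlnet=controlnet,
        cn_cond=cn_cond,
        scheduler=scheduler,
        verbose=False,
    )
```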
class VQVAETransformerInferer(nn.Module):
|
1829
|
+
"""
|
1830
|
+
Class to perform inference with a VQVAE + Transformer model.
|
1831
|
+
"""
|
1832
|
+
|
1833
|
+
def __init__(self) -> None:
|
1834
|
+
Inferer.__init__(self)
|
1835
|
+
|
1836
|
+
    def __call__(
        self,
        inputs: torch.Tensor,
        vqvae_model: VQVAE,
        transformer_model: DecoderOnlyTransformer,
        ordering: Ordering,
        condition: torch.Tensor | None = None,
        return_latent: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
        """
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: input image from which the latent representation will be extracted.
            vqvae_model: first stage model.
            transformer_model: autoregressive transformer model.
            ordering: ordering of the quantised latent representation.
            condition: conditioning for network input.
            return_latent: also return the latent sequence and the spatial dimensions of the latent.
        """
        with torch.no_grad():
            latent = vqvae_model.index_quantize(inputs)

        latent_spatial_dim = tuple(latent.shape[1:])
        latent = latent.reshape(latent.shape[0], -1)
        latent = latent[:, ordering.get_sequence_ordering()]

        # get the targets for the loss
        target = latent.clone()
        # Use the value of vqvae_model.num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
        # Note: the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
        # crop the last token as we do not need the probability of the token that follows it
        latent = latent[:, :-1]
        latent = latent.long()

        # train on a part of the sequence if it is longer than max_seq_len
        seq_len = latent.shape[1]
        max_seq_len = transformer_model.max_seq_len
        if max_seq_len < seq_len:
            start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item())
        else:
            start = 0
        prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
        if return_latent:
            return prediction, target[:, start : start + max_seq_len], latent_spatial_dim
        else:
            return prediction
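# Editor's note: an illustrative training-step sketch, not part of the released file. It
# assumes `vqvae`, `transformer` and `ordering` are already constructed and compatible
# (the transformer's num_tokens equal to vqvae.num_embeddings + 1, as noted above).
#
#     inferer = VQVAETransformerInferer()
#     logits, target, _ = inferer(
#         inputs=images,
#         vqvae_model=vqvae,
#         transformer_model=transformer,
#         ordering=ordering,
#         return_latent=True,
#     )
#     # cross-entropy over the predicted codebook indices at each sequence position
#     loss = F.cross_entropy(logits.transpose(1, 2), target.long())
#     loss.backward()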
    @torch.no_grad()
    def sample(
        self,
        latent_spatial_dim: tuple[int, int, int] | tuple[int, int],
        starting_tokens: torch.Tensor,
        vqvae_model: VQVAE,
        transformer_model: DecoderOnlyTransformer,
        ordering: Ordering,
        conditioning: torch.Tensor | None = None,
        temperature: float = 1.0,
        top_k: int | None = None,
        verbose: bool = True,
    ) -> torch.Tensor:
        """
        Sampling function for the VQVAE + Transformer model.

        Args:
            latent_spatial_dim: shape of the sampled image.
            starting_tokens: starting tokens for the sampling. They must all be set to the vqvae_model.num_embeddings
                value, i.e. the "Begin Of Sentence" (BOS) token.
            vqvae_model: first stage model.
            transformer_model: model to sample from.
            ordering: ordering of the quantised latent representation.
            conditioning: conditioning for network input.
            temperature: temperature for sampling.
            top_k: if not None, restrict sampling to the top k most likely tokens.
            verbose: if true, prints the progress bar of the sampling process.
        """
        seq_len = math.prod(latent_spatial_dim)

        if verbose and has_tqdm:
            progress_bar = tqdm(range(seq_len))
        else:
            progress_bar = iter(range(seq_len))

        latent_seq = starting_tokens.long()
        for _ in progress_bar:
            # if the sequence context is growing too long we must crop it at max_seq_len
            if latent_seq.size(1) <= transformer_model.max_seq_len:
                idx_cond = latent_seq
            else:
                idx_cond = latent_seq[:, -transformer_model.max_seq_len :]

            # forward the model to get the logits for the next index in the sequence
            logits = transformer_model(x=idx_cond, context=conditioning)
            # pluck the logits at the final step and scale by the desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # remove any chance of sampling the BOS token
            probs[:, vqvae_model.num_embeddings] = 0
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            latent_seq = torch.cat((latent_seq, idx_next), dim=1)

        # drop the BOS token, undo the sequence ordering and restore the latent spatial shape
        latent_seq = latent_seq[:, 1:]
        latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
        latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)

        return vqvae_model.decode_samples(latent)
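# Editor's note: an illustrative sampling sketch, not part of the released file. It
# assumes a 2D latent grid of 16x16 tokens and previously constructed `vqvae`,
# `transformer` and `ordering` objects.
#
#     inferer = VQVAETransformerInferer()
#     bos = vqvae.num_embeddings  # the BOS token index expected by `sample`
#     starting_tokens = torch.full((batch_size, 1), bos)
#     samples = inferer.sample(
#         latent_spatial_dim=(16, 16),
#         starting_tokens=starting_tokens,
#         vqvae_model=vqvae,
#         transformer_model=transformer,
#         ordering=ordering,
#         top_k=50,
#     )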
    @torch.no_grad()
    def get_likelihood(
        self,
        inputs: torch.Tensor,
        vqvae_model: VQVAE,
        transformer_model: DecoderOnlyTransformer,
        ordering: Ordering,
        condition: torch.Tensor | None = None,
        resample_latent_likelihoods: bool = False,
        resample_interpolation_mode: str = "nearest",
        verbose: bool = False,
    ) -> torch.Tensor:
        """
        Computes the log-likelihoods of the latent representations of the input.

        Args:
            inputs: input images, NxCxHxW[xD]
            vqvae_model: first stage model.
            transformer_model: autoregressive transformer model.
            ordering: ordering of the quantised latent representation.
            condition: conditioning for network input.
            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                dimensions as the input images.
            resample_interpolation_mode: if resample_latent_likelihoods is used, the interpolation mode: 'nearest',
                'bilinear', or 'trilinear'.
            verbose: if true, prints the progress bar of the sampling process.
        """
        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
            raise ValueError(
                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
            )

        with torch.no_grad():
            latent = vqvae_model.index_quantize(inputs)

        latent_spatial_dim = tuple(latent.shape[1:])
        latent = latent.reshape(latent.shape[0], -1)
        latent = latent[:, ordering.get_sequence_ordering()]
        seq_len = math.prod(latent_spatial_dim)

        # Use the value of vqvae_model.num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
        # Note: the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
        latent = latent.long()

        # get the first batch, up to max_seq_len, efficiently
        logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
        probs = F.softmax(logits, dim=-1)
        # the target token for each set of logits is the next token along
        target = latent[:, 1:]
        probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)

        # if we have not covered the full sequence we continue with inefficient looping
        if probs.shape[1] < target.shape[1]:
            if verbose and has_tqdm:
                progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len))
            else:
                progress_bar = iter(range(transformer_model.max_seq_len, seq_len))

            for i in progress_bar:
                idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
                # forward the model to get the logits for the index in the sequence
                logits = transformer_model(x=idx_cond, context=condition)
                # pluck the logits at the final step
                logits = logits[:, -1, :]
                # apply softmax to convert logits to (normalized) probabilities
                p = F.softmax(logits, dim=-1)
                # select the probability of the correct (target) token and append
                p = torch.gather(p, 1, target[:, i].unsqueeze(1))

                probs = torch.cat((probs, p), dim=1)

        # convert to log-likelihood
        probs = torch.log(probs)

        # undo the sequence ordering and restore the latent spatial shape
        probs = probs[:, ordering.get_revert_sequence_ordering()]
        probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
        if resample_latent_likelihoods:
            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
            probs_reshaped = resizer(probs_reshaped[:, None, ...])

        return probs_reshaped
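# Editor's note: an illustrative likelihood sketch, not part of the released file.
# `get_likelihood` returns per-token log-likelihoods shaped like the latent grid (or
# upsampled to the image grid when resample_latent_likelihoods=True), so a per-image
# log-likelihood can be obtained by summing over the spatial dimensions:
#
#     log_probs = inferer.get_likelihood(
#         inputs=images,
#         vqvae_model=vqvae,
#         transformer_model=transformer,
#         ordering=ordering,
#     )
#     per_image_ll = log_probs.flatten(start_dim=1).sum(dim=1)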