monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (71)
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/inferers/inferer.py CHANGED
@@ -11,24 +11,41 @@
 
 from __future__ import annotations
 
+import math
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from functools import partial
 from pydoc import locate
 from typing import Any
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.apps.utils import get_logger
+from monai.data import decollate_batch
 from monai.data.meta_tensor import MetaTensor
 from monai.data.thread_buffer import ThreadBuffer
 from monai.inferers.merger import AvgMerger, Merger
 from monai.inferers.splitter import Splitter
 from monai.inferers.utils import compute_importance_map, sliding_window_inference
-from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
+from monai.networks.nets import (
+    VQVAE,
+    AutoencoderKL,
+    ControlNet,
+    DecoderOnlyTransformer,
+    DiffusionModelUNet,
+    SPADEAutoencoderKL,
+    SPADEDiffusionModelUNet,
+)
+from monai.networks.schedulers import Scheduler
+from monai.transforms import CenterSpatialCrop, SpatialPad
+from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
 from monai.visualize import CAM, GradCAM, GradCAMpp
 
+tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+
 logger = get_logger(__name__)
 
 __all__ = [
@@ -752,3 +769,1264 @@ class SliceInferer(SlidingWindowInferer):
             return out
 
         return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out)
+
+
+class DiffusionInferer(Inferer):
+    """
+    DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
+    for a training iteration, and sample from the model.
+
+    Args:
+        scheduler: diffusion scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler) -> None:  # type: ignore[override]
+        super().__init__()
+
+        self.scheduler = scheduler
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: Input image to which noise is added.
+            diffusion_model: diffusion model.
+            noise: random noise, of the same shape as the input.
+            timesteps: random timesteps.
+            condition: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
+                provided on the forward (for SPADE-like AE or SPADE-like DM)
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+        if mode == "concat":
+            if condition is None:
+                raise ValueError("Conditioning is required for concat condition")
+            else:
+                noisy_image = torch.cat([noisy_image, condition], dim=1)
+                condition = None
+        diffusion_model = (
+            partial(diffusion_model, seg=seg)
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            else diffusion_model
+        )
+        prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(
+        self,
+        input_noise: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired sample.
+            diffusion_model: model to sample from.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if mode == "concat" and conditioning is None:
+            raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+        if not scheduler:
+            scheduler = self.scheduler
+        image = input_noise
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        for t in progress_bar:
+            # 1. predict noise model_output
+            diffusion_model = (
+                partial(diffusion_model, seg=seg)
+                if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+                else diffusion_model
+            )
+            if mode == "concat" and conditioning is not None:
+                model_input = torch.cat([image, conditioning], dim=1)
+                model_output = diffusion_model(
+                    model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
+                )
+            else:
+                model_output = diffusion_model(
+                    image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+                )
+
+            # 2. compute previous image: x_t -> x_t-1
+            image, _ = scheduler.step(model_output, t, image)
+            if save_intermediates and t % intermediate_steps == 0:
+                intermediates.append(image)
+        if save_intermediates:
+            return image, intermediates
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods for an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                f"Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if mode == "concat" and conditioning is None:
+            raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            diffusion_model = (
+                partial(diffusion_model, seg=seg)
+                if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+                else diffusion_model
+            )
+            if mode == "concat" and conditioning is not None:
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
+            else:
+                model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+
+            # get the model's predicted mean, and variance if it is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute predicted original sample from predicted noise also called
+            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+            # 3. Clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. Compute predicted previous sample µ_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -self._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
+    def _approx_standard_normal_cdf(self, x):
+        """
+        A fast approximation of the cumulative distribution function of the
+        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+        """
+
+        return 0.5 * (
+            1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
+        )
+
+    def _get_decoder_log_likelihood(
+        self,
+        inputs: torch.Tensor,
+        means: torch.Tensor,
+        log_scales: torch.Tensor,
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+    ) -> torch.Tensor:
+        """
+        Compute the log-likelihood of a Gaussian distribution discretizing to a
+        given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+        Args:
+            input: the target images. It is assumed that this was uint8 values,
+                rescaled to the range [-1, 1].
+            means: the Gaussian mean Tensor.
+            log_scales: the Gaussian log stddev Tensor.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+        """
+        if inputs.shape != means.shape:
+            raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}")
+        bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
+            original_input_range[1] - original_input_range[0]
+        )
+        centered_x = inputs - means
+        inv_stdv = torch.exp(-log_scales)
+        plus_in = inv_stdv * (centered_x + bin_width / 2)
+        cdf_plus = self._approx_standard_normal_cdf(plus_in)
+        min_in = inv_stdv * (centered_x - bin_width / 2)
+        cdf_min = self._approx_standard_normal_cdf(min_in)
+        log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+        log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+        cdf_delta = cdf_plus - cdf_min
+        log_probs = torch.where(
+            inputs < -0.999,
+            log_cdf_plus,
+            torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+        )
+        return log_probs
+
+
+class LatentDiffusionInferer(DiffusionInferer):
+    """
+    LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
+    be used to perform a signal forward pass for a training iteration, and sample from the model.
+
+    Args:
+        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+            second stage.
+        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+        autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+            difference between the autoencoder's latent shape and the DM shape.
+    """
+
+    def __init__(
+        self,
+        scheduler: Scheduler,
+        scale_factor: float = 1.0,
+        ldm_latent_shape: list | None = None,
+        autoencoder_latent_shape: list | None = None,
+    ) -> None:
+        super().__init__(scheduler=scheduler)
+        self.scale_factor = scale_factor
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.")
+        self.ldm_latent_shape = ldm_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
+        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image to which the latent representation will be extracted and noise is added.
+            autoencoder_model: first stage model.
+            diffusion_model: diffusion model.
+            noise: random noise, of the same shape as the latent representation.
+            timesteps: random timesteps.
+            condition: conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        with torch.no_grad():
+            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+        prediction: torch.Tensor = super().__call__(
+            inputs=latent,
+            diffusion_model=diffusion_model,
+            noise=noise,
+            timesteps=timesteps,
+            condition=condition,
+            mode=mode,
+            seg=seg,
+        )
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired latent representation.
+            autoencoder_model: first stage model.
+            diffusion_model: model to sample from.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+
+        if (
+            isinstance(autoencoder_model, SPADEAutoencoderKL)
+            and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+        ):
+            raise ValueError(
+                f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+                f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and"
+                f"{diffusion_model.label_nc}"
+            )
+
+        outputs = super().sample(
+            input_noise=input_noise,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            intermediate_steps=intermediate_steps,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates:
+            latent, latent_intermediates = outputs
+        else:
+            latent = outputs
+
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
+
+        decode = autoencoder_model.decode_stage_2_outputs
+        if isinstance(autoencoder_model, SPADEAutoencoderKL):
+            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+        image = decode(latent / self.scale_factor)
+
+        if save_intermediates:
+            intermediates = []
+            for latent_intermediate in latent_intermediates:
+                decode = autoencoder_model.decode_stage_2_outputs
+                if isinstance(autoencoder_model, SPADEAutoencoderKL):
+                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+                intermediates.append(decode(latent_intermediate / self.scale_factor))
+            return image, intermediates
+
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool = True,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "nearest",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
+
+
+class ControlNetDiffusionInferer(DiffusionInferer):
+    """
+    ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal
+    forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.
+
+    Args:
+        scheduler: diffusion scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler) -> None:
+        Inferer.__init__(self)
+        self.scheduler = scheduler
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        cn_cond: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: Input image to which noise is added.
+            diffusion_model: diffusion model.
+            controlnet: controlnet sub-network.
+            noise: random noise, of the same shape as the input.
+            timesteps: random timesteps.
+            cn_cond: conditioning image for the ControlNet.
+            condition: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
+                provided on the forward (for SPADE-like AE or SPADE-like DM)
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
+        )
+        if mode == "concat" and condition is not None:
+            noisy_image = torch.cat([noisy_image, condition], dim=1)
+            condition = None
+
+        diffuse = diffusion_model
+        if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+            diffuse = partial(diffusion_model, seg=seg)
+
+        prediction: torch.Tensor = diffuse(
+            x=noisy_image,
+            timesteps=timesteps,
+            context=condition,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
+        )
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired sample.
+            diffusion_model: model to sample from.
+            controlnet: controlnet sub-network.
+            cn_cond: conditioning image for the ControlNet.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+
+        if not scheduler:
+            scheduler = self.scheduler
+        image = input_noise
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        for t in progress_bar:
+            # 1. ControlNet forward
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
+            )
+            # 2. predict noise model_output
+            diffuse = diffusion_model
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+                diffuse = partial(diffusion_model, seg=seg)
+
+            if mode == "concat" and conditioning is not None:
+                model_input = torch.cat([image, conditioning], dim=1)
+                model_output = diffuse(
+                    model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    context=None,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            else:
+                model_output = diffuse(
+                    image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    context=conditioning,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+
+            # 3. compute previous image: x_t -> x_t-1
+            image, _ = scheduler.step(model_output, t, image)
+            if save_intermediates and t % intermediate_steps == 0:
+                intermediates.append(image)
+        if save_intermediates:
+            return image, intermediates
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple = (0, 255),
+        scaled_input_range: tuple = (0, 1),
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods for an input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            diffusion_model: model to compute likelihood from
+            controlnet: controlnet sub-network.
+            cn_cond: conditioning image for the ControlNet.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+
+        if not scheduler:
+            scheduler = self.scheduler
+        if scheduler._get_name() != "DDPMScheduler":
+            raise NotImplementedError(
+                f"Likelihood computation is only compatible with DDPMScheduler,"
+                f" you are using {scheduler._get_name()}"
+            )
+        if mode not in ["crossattn", "concat"]:
+            raise NotImplementedError(f"{mode} condition is not supported")
+        if verbose and has_tqdm:
+            progress_bar = tqdm(scheduler.timesteps)
+        else:
+            progress_bar = iter(scheduler.timesteps)
+        intermediates = []
+        noise = torch.randn_like(inputs).to(inputs.device)
+        total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+        for t in progress_bar:
+            timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+            noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+            down_block_res_samples, mid_block_res_sample = controlnet(
+                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
+            )
+
+            diffuse = diffusion_model
+            if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+                diffuse = partial(diffusion_model, seg=seg)
+
+            if mode == "concat" and conditioning is not None:
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                model_output = diffuse(
+                    noisy_image,
+                    timesteps=timesteps,
+                    context=None,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            else:
+                model_output = diffuse(
+                    x=noisy_image,
+                    timesteps=timesteps,
+                    context=conditioning,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                )
+            # get the model's predicted mean, and variance if it is predicted
+            if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+                model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+            else:
+                predicted_variance = None
+
+            # 1. compute alphas, betas
+            alpha_prod_t = scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+            beta_prod_t = 1 - alpha_prod_t
+            beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+            # 2. compute predicted original sample from predicted noise also called
+            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+            if scheduler.prediction_type == "epsilon":
+                pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            elif scheduler.prediction_type == "sample":
+                pred_original_sample = model_output
+            elif scheduler.prediction_type == "v_prediction":
+                pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+            # 3. Clip "predicted x_0"
+            if scheduler.clip_sample:
+                pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+            # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+            current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+            # 5. Compute predicted previous sample µ_t
+            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+            predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+            # get the posterior mean and variance
+            posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+            posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+            log_posterior_variance = torch.log(posterior_variance)
+            log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+
+            if t == 0:
+                # compute -log p(x_0|x_1)
+                kl = -super()._get_decoder_log_likelihood(
+                    inputs=inputs,
+                    means=predicted_mean,
+                    log_scales=0.5 * log_predicted_variance,
+                    original_input_range=original_input_range,
+                    scaled_input_range=scaled_input_range,
+                )
+            else:
+                # compute kl between two normals
+                kl = 0.5 * (
+                    -1.0
+                    + log_predicted_variance
+                    - log_posterior_variance
+                    + torch.exp(log_posterior_variance - log_predicted_variance)
+                    + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+                )
+            total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+            if save_intermediates:
+                intermediates.append(kl.cpu())
+
+        if save_intermediates:
+            return total_kl, intermediates
+        else:
+            return total_kl
+
+
+class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
+    """
+    ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
+    and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from
+    the model.
+
+    Args:
+        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+            second stage.
+        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+        autoencoder_latent_shape: autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+            difference between the autoencoder's latent shape and the DM shape.
+    """
+
+    def __init__(
+        self,
+        scheduler: Scheduler,
+        scale_factor: float = 1.0,
+        ldm_latent_shape: list | None = None,
+        autoencoder_latent_shape: list | None = None,
+    ) -> None:
+        super().__init__(scheduler=scheduler)
+        self.scale_factor = scale_factor
+        if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+            raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.")
+        self.ldm_latent_shape = ldm_latent_shape
+        self.autoencoder_latent_shape = autoencoder_latent_shape
+        if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+            self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+            self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
+    def __call__(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+        cn_cond: torch.Tensor,
+        condition: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image to which the latent representation will be extracted and noise is added.
+            autoencoder_model: first stage model.
+            diffusion_model: diffusion model.
+            controlnet: instance of ControlNet model
+            noise: random noise, of the same shape as the latent representation.
+            timesteps: random timesteps.
+            cn_cond: conditioning tensor for the ControlNet network
+            condition: conditioning for network input.
+            mode: Conditioning mode for the network.
+            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
+        """
+        with torch.no_grad():
+            latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if self.ldm_latent_shape is not None:
+            latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+        if cn_cond.shape[2:] != latent.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, latent.shape[2:])
+
+        prediction = super().__call__(
+            inputs=latent,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            noise=noise,
+            timesteps=timesteps,
+            cn_cond=cn_cond,
+            condition=condition,
+            mode=mode,
+            seg=seg,
+        )
+
+        return prediction
+
+    @torch.no_grad()
+    def sample(  # type: ignore[override]
+        self,
+        input_noise: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        intermediate_steps: int | None = 100,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        verbose: bool = True,
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_noise: random noise, of the same shape as the desired latent representation.
+            autoencoder_model: first stage model.
+            diffusion_model: model to sample from.
+            controlnet: instance of ControlNet model.
+            cn_cond: conditioning tensor for the ControlNet network.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+            save_intermediates: whether to return intermediates along the sampling change
+            intermediate_steps: if save_intermediates is True, saves every n steps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            verbose: if true, prints the progression bar of the sampling process.
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+
+        if (
+            isinstance(autoencoder_model, SPADEAutoencoderKL)
+            and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+            and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+        ):
+            raise ValueError(
+                "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+                "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}"
+            )
+
+        if cn_cond.shape[2:] != input_noise.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, input_noise.shape[2:])
+
+        outputs = super().sample(
+            input_noise=input_noise,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            cn_cond=cn_cond,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            intermediate_steps=intermediate_steps,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates:
+            latent, latent_intermediates = outputs
+        else:
+            latent = outputs
+
+        if self.autoencoder_latent_shape is not None:
+            latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+            latent_intermediates = [
+                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+            ]
+
+        decode = autoencoder_model.decode_stage_2_outputs
+        if isinstance(autoencoder_model, SPADEAutoencoderKL):
+            decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+
+        image = decode(latent / self.scale_factor)
+
+        if save_intermediates:
+            intermediates = []
+            for latent_intermediate in latent_intermediates:
+                decode = autoencoder_model.decode_stage_2_outputs
+                if isinstance(autoencoder_model, SPADEAutoencoderKL):
+                    decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+                intermediates.append(decode(latent_intermediate / self.scale_factor))
+            return image, intermediates
+
+        else:
+            return image
+
+    @torch.no_grad()
+    def get_likelihood(  # type: ignore[override]
+        self,
+        inputs: torch.Tensor,
+        autoencoder_model: AutoencoderKL | VQVAE,
+        diffusion_model: DiffusionModelUNet,
+        controlnet: ControlNet,
+        cn_cond: torch.Tensor,
+        scheduler: Scheduler | None = None,
+        save_intermediates: bool | None = False,
+        conditioning: torch.Tensor | None = None,
+        mode: str = "crossattn",
+        original_input_range: tuple | None = (0, 255),
+        scaled_input_range: tuple | None = (0, 1),
+        verbose: bool = True,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "nearest",
+        seg: torch.Tensor | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            autoencoder_model: first stage model.
+            diffusion_model: model to compute likelihood from
+            controlnet: instance of ControlNet model.
+            cn_cond: conditioning tensor for the ControlNet network.
+            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+            save_intermediates: save the intermediate spatial KL maps
+            conditioning: Conditioning for network input.
+            mode: Conditioning mode for the network.
+            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+            scaled_input_range: the [min,max] intensity range of the input data after scaling.
+            verbose: if true, prints the progression bar of the sampling process.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
+            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
+                is instance of SPADEAutoencoderKL, segmentation must be provided.
+        """
+        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+
+        latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+        if cn_cond.shape[2:] != latents.shape[2:]:
+            cn_cond = F.interpolate(cn_cond, latents.shape[2:])
+
+        if self.ldm_latent_shape is not None:
+            latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+        outputs = super().get_likelihood(
+            inputs=latents,
+            diffusion_model=diffusion_model,
+            controlnet=controlnet,
+            cn_cond=cn_cond,
+            scheduler=scheduler,
+            save_intermediates=save_intermediates,
+            conditioning=conditioning,
+            mode=mode,
+            verbose=verbose,
+            seg=seg,
+        )
+
+        if save_intermediates and resample_latent_likelihoods:
+            intermediates = outputs[1]
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            intermediates = [resizer(x) for x in intermediates]
+            outputs = (outputs[0], intermediates)
+        return outputs
+
+
+class VQVAETransformerInferer(nn.Module):
+    """
+    Class to perform inference with a VQVAE + Transformer model.
+    """
+
+    def __init__(self) -> None:
+        Inferer.__init__(self)
+
+    def __call__(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: VQVAE,
+        transformer_model: DecoderOnlyTransformer,
+        ordering: Ordering,
+        condition: torch.Tensor | None = None,
+        return_latent: bool = False,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
+        """
+        Implements the forward pass for a supervised training iteration.
+
+        Args:
+            inputs: input image to which the latent representation will be extracted.
+            vqvae_model: first stage model.
+            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            return_latent: also return latent sequence and spatial dim of the latent.
+            condition: conditioning for network input.
+        """
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent_spatial_dim = tuple(latent.shape[1:])
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+
+        # get the targets for the loss
+        target = latent.clone()
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        # crop the last token as we do not need the probability of the token that follows it
+        latent = latent[:, :-1]
+        latent = latent.long()
+
+        # train on a part of the sequence if it is longer than max_seq_length
+        seq_len = latent.shape[1]
+        max_seq_len = transformer_model.max_seq_len
+        if max_seq_len < seq_len:
+            start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item())
+        else:
+            start = 0
+        prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
+        if return_latent:
+            return prediction, target[:, start : start + max_seq_len], latent_spatial_dim
+        else:
+            return prediction
+
+    @torch.no_grad()
+    def sample(
+        self,
+        latent_spatial_dim: tuple[int, int, int] | tuple[int, int],
+        starting_tokens: torch.Tensor,
+        vqvae_model: VQVAE,
+        transformer_model: DecoderOnlyTransformer,
+        ordering: Ordering,
+        conditioning: torch.Tensor | None = None,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        verbose: bool = True,
+    ) -> torch.Tensor:
+        """
+        Sampling function for the VQVAE + Transformer model.
+
+        Args:
+            latent_spatial_dim: shape of the sampled image.
+            starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value.
+            vqvae_model: first stage model.
+            transformer_model: model to sample from.
+            conditioning: Conditioning for network input.
+            temperature: temperature for sampling.
+            top_k: top k sampling.
+            verbose: if true, prints the progression bar of the sampling process.
+        """
+        seq_len = math.prod(latent_spatial_dim)
+
+        if verbose and has_tqdm:
+            progress_bar = tqdm(range(seq_len))
+        else:
+            progress_bar = iter(range(seq_len))
+
+        latent_seq = starting_tokens.long()
+        for _ in progress_bar:
+            # if the sequence context is growing too long we must crop it at block_size
+            if latent_seq.size(1) <= transformer_model.max_seq_len:
+                idx_cond = latent_seq
+            else:
+                idx_cond = latent_seq[:, -transformer_model.max_seq_len :]
+
+            # forward the model to get the logits for the index in the sequence
+            logits = transformer_model(x=idx_cond, context=conditioning)
+            # pluck the logits at the final step and scale by desired temperature
+            logits = logits[:, -1, :] / temperature
+            # optionally crop the logits to only the top k options
+            if top_k is not None:
+                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                logits[logits < v[:, [-1]]] = -float("Inf")
+            # apply softmax to convert logits to (normalized) probabilities
+            probs = F.softmax(logits, dim=-1)
+            # remove the chance to be sampled the BOS token
+            probs[:, vqvae_model.num_embeddings] = 0
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1)
+            # append sampled index to the running sequence and continue
+            latent_seq = torch.cat((latent_seq, idx_next), dim=1)
+
+        latent_seq = latent_seq[:, 1:]
+        latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
+        latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)
+
+        return vqvae_model.decode_samples(latent)
+
+    @torch.no_grad()
+    def get_likelihood(
+        self,
+        inputs: torch.Tensor,
+        vqvae_model: VQVAE,
+        transformer_model: DecoderOnlyTransformer,
+        ordering: Ordering,
+        condition: torch.Tensor | None = None,
+        resample_latent_likelihoods: bool = False,
+        resample_interpolation_mode: str = "nearest",
+        verbose: bool = False,
+    ) -> torch.Tensor:
+        """
+        Computes the log-likelihoods of the latent representations of the input.
+
+        Args:
+            inputs: input images, NxCxHxW[xD]
+            vqvae_model: first stage model.
+            transformer_model: autoregressive transformer model.
+            ordering: ordering of the quantised latent representation.
+            condition: conditioning for network input.
+            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+                dimension as the input images.
+            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+                or 'trilinear;
+            verbose: if true, prints the progression bar of the sampling process.
+
+        """
+        if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+            raise ValueError(
+                f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+            )
+
+        with torch.no_grad():
+            latent = vqvae_model.index_quantize(inputs)
+
+        latent_spatial_dim = tuple(latent.shape[1:])
+        latent = latent.reshape(latent.shape[0], -1)
+        latent = latent[:, ordering.get_sequence_ordering()]
+        seq_len = math.prod(latent_spatial_dim)
+
+        # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+        # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+        latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+        latent = latent.long()
+
+        # get the first batch, up to max_seq_length, efficiently
+        logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
+        probs = F.softmax(logits, dim=-1)
+        # target token for each set of logits is the next token along
+        target = latent[:, 1:]
+        probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
+
+        # if we have not covered the full sequence we continue with inefficient looping
+        if probs.shape[1] < target.shape[1]:
+            if verbose and has_tqdm:
+                progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len))
+            else:
+                progress_bar = iter(range(transformer_model.max_seq_len, seq_len))
+
+            for i in progress_bar:
+                idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
+                # forward the model to get the logits for the index in the sequence
+                logits = transformer_model(x=idx_cond, context=condition)
+                # pluck the logits at the final step
+                logits = logits[:, -1, :]
+                # apply softmax to convert logits to (normalized) probabilities
+                p = F.softmax(logits, dim=-1)
+                # select correct values and append
+                p = torch.gather(p, 1, target[:, i].unsqueeze(1))
+
+                probs = torch.cat((probs, p), dim=1)
+
+        # convert to log-likelihood
+        probs = torch.log(probs)
+
+        # reshape
+        probs = probs[:, ordering.get_revert_sequence_ordering()]
+        probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
+        if resample_latent_likelihoods:
+            resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+            probs_reshaped = resizer(probs_reshaped[:, None, ...])
+
+        return probs_reshaped
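
For orientation, a minimal usage sketch of the newly added DiffusionInferer, assuming the 1.4.dev2430 API shown in this diff. DDPMScheduler comes from the new monai/networks/schedulers module listed above; the network hyperparameters, tensor shapes, and step counts below are illustrative assumptions rather than values taken from the package.

# Minimal sketch only: model hyperparameters, shapes and step counts are assumptions.
import torch

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=16,
)
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)

images = torch.randn(2, 1, 64, 64)  # toy batch
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()

# training-style forward pass: adds noise to `images` at `timesteps` and returns the model prediction
prediction = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)

# sampling: start from noise and iterate over the scheduler's timesteps
scheduler.set_timesteps(num_inference_steps=1000)
samples = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False)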