diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
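Note: with `LoraLoaderMixin` added to the class bases, the inpainting pipeline gains `load_lora_weights` / `save_lora_weights`. A minimal usage sketch; the model id and LoRA path below are illustrative placeholders, not taken from this diff:

    import torch
    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    )
    # Load LoRA attention weights in the diffusers format; accepts a Hub repo id or a local directory.
    pipe.load_lora_weights("path/to/lora_weights")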
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -22,7 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)


-def preprocess_image(image):
+def preprocess_image(image, batch_size):
     w, h = image.size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0


-def preprocess_mask(mask, scale_factor=8):
+def preprocess_mask(mask, batch_size, scale_factor=8):
     if not isinstance(mask, torch.FloatTensor):
         mask = mask.convert("L")
         w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
         mask = np.array(mask).astype(np.float32) / 255.0
         mask = np.tile(mask, (4, 1, 1))
-        mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+        mask = np.vstack([mask[None]] * batch_size)
         mask = 1 - mask  # repaint white, keep black
         mask = torch.from_numpy(mask)
         return mask
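Note: taken together, these hunks move batch duplication into preprocessing; `preprocess_image` and `preprocess_mask` now tile their output `batch_size` times via `np.vstack`. A toy shape check, assuming a 512x512 RGB input already scaled to [0, 1]:

    import numpy as np

    image = np.random.rand(512, 512, 3).astype(np.float32)
    batch_size = 2

    batched = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
    print(batched.shape)  # (2, 3, 512, 512): one NCHW copy per prompt in the batch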
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
         return mask


-class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInpaintPipelineLegacy(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
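Note: the legacy inpainting pipeline also picks up `FromCkptMixin`, so it can be instantiated from a single-file Stable Diffusion checkpoint. A minimal sketch; the checkpoint path is a placeholder:

    from diffusers import StableDiffusionInpaintPipelineLegacy

    # Loads a .ckpt/.safetensors checkpoint and converts it to the diffusers layout on the fly.
    pipe = StableDiffusionInpaintPipelineLegacy.from_ckpt("path/to/inpainting_checkpoint.ckpt")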
@@ -511,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

         return timesteps, num_inference_steps - t_start

-    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+    def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents

         # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents

         # add noise to latents using the timesteps
@@ -649,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

         # 4. Preprocess image and mask
         if not isinstance(image, torch.FloatTensor):
-            image = preprocess_image(image)
+            image = preprocess_image(image, batch_size)

-        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -661,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
         # 6. Prepare latent variables
         # encode the init image into latents and scale the latents
         latents, init_latents_orig, noise = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
         )

         # 7. Prepare mask latent
         mask = mask_image.to(device=self.device, dtype=latents.dtype)
-        mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        mask = torch.cat([mask] * num_images_per_prompt)

         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -20,7 +20,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
     return image


-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
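Note: the InstructPix2Pix pipeline gets the same textual-inversion and LoRA loaders. A short sketch; the repo id and local path are illustrative placeholders:

    import torch
    from diffusers import StableDiffusionInstructPix2PixPipeline

    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
    )
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # textual-inversion embedding
    pipe.load_lora_weights("path/to/lora_weights")              # LoRA attention weights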
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -36,6 +36,7 @@ from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -721,23 +722,31 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
            )

        if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
        else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents for batch_size
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+                    " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+                    " your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
        else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        latents = init_latents
+            latents = torch.cat([latents], dim=0)

        return latents

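Note: the former hard error when the number of prompts and initial images differed is now a deprecation path; as long as the prompt count is a whole multiple of the image count, the encoded latents are tiled. A toy illustration of the rule with made-up shapes:

    import torch

    latents = torch.randn(1, 4, 64, 64)  # one encoded image
    batch_size = 4                       # four text prompts

    if batch_size % latents.shape[0] == 0:
        latents = torch.cat([latents] * (batch_size // latents.shape[0]), dim=0)
    print(latents.shape)  # torch.Size([4, 4, 64, 64])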
@@ -759,23 +768,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
            )

    def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
        reg_loss = 0.0
        for i in range(hidden_states.shape[0]):
-            noise = hidden_states[i][None, None, :, :]
-            while True:
-                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
-
-                if noise.shape[2] <= 8:
-                    break
-                noise = F.avg_pool2d(noise, kernel_size=2)
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss

    def kl_divergence(self, hidden_states):
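Note: `auto_corr_loss` no longer requires `batch_size == 1`; it accumulates the auto-correlation regularizer over every (batch, channel) slice. Restated standalone for illustration, same logic as the hunk above but runnable outside the pipeline:

    import torch
    import torch.nn.functional as F

    def auto_corr_loss(hidden_states, generator=None):
        reg_loss = 0.0
        for i in range(hidden_states.shape[0]):
            for j in range(hidden_states.shape[1]):
                noise = hidden_states[i : i + 1, j : j + 1, :, :]
                while True:
                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                    if noise.shape[2] <= 8:
                        break
                    noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss

    print(auto_corr_loss(torch.randn(2, 4, 64, 64)))  # batched inputs now work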
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -13,18 +13,20 @@
 # limitations under the License.

 import inspect
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union

 import numpy as np
 import PIL
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from . import StableDiffusionPipelineOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -76,6 +78,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """
+    _optional_components = ["watermarker", "safety_checker", "feature_extractor"]

    def __init__(
        self,
@@ -85,12 +88,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
        unet: UNet2DConditionModel,
        low_res_scheduler: DDPMScheduler,
        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: Optional[Any] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        watermarker: Optional[Any] = None,
        max_noise_level: int = 350,
    ):
        super().__init__()

-        if hasattr(vae, "config"):
-            # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+        if hasattr(
+            vae, "config"
+        ):  # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
            is_vae_scaling_factor_set_to_0_08333 = (
                hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
            )
@@ -113,6 +120,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
            unet=unet,
            low_res_scheduler=low_res_scheduler,
            scheduler=scheduler,
+            safety_checker=safety_checker,
+            watermarker=watermarker,
+            feature_extractor=feature_extractor,
        )
        self.register_to_config(max_noise_level=max_noise_level)

@@ -129,10 +139,36 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi

        device = torch.device(f"cuda:{gpu_id}")

-        for cpu_offloaded_model in [self.unet, self.text_encoder]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            if cpu_offloaded_model is not None:
+                _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
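Note: continuing the sketch above, the upscale pipeline now exposes the same model-level offloading helper as the other Stable Diffusion pipelines; it requires accelerate >= 0.17.0:

    # Each sub-model (text encoder, unet, vae) is moved to the GPU only while its forward pass runs.
    pipe.enable_model_cpu_offload()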
@@ -152,6 +188,23 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
                return torch.device(module._hf_hook.execution_device)
        return self.device

+    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+            image, nsfw_detected, watermark_detected = self.safety_checker(
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+            )
+        else:
+            nsfw_detected = None
+            watermark_detected = None
+
+        if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+            self.unet_offload_hook.offload()
+
+        return image, nsfw_detected, watermark_detected
+
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
@@ -645,13 +698,43 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
        # 10. Post-processing
        # make sure the VAE is in float32 mode, as it overflows in float16
        self.vae.to(dtype=torch.float32)
-        image = self.decode_latents(latents.float())
+
+        # TODO(Patrick, William) - clean up when attention is refactored
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+        use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if not use_torch_2_0_attn and not use_xformers:
+            self.vae.post_quant_conv.to(latents.dtype)
+            self.vae.decoder.conv_in.to(latents.dtype)
+            self.vae.decoder.mid_block.to(latents.dtype)
+        else:
+            latents = latents.float()

        # 11. Convert to PIL
        if output_type == "pil":
+            image = self.decode_latents(latents)
+
+            image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
            image = self.numpy_to_pil(image)

+            # 11. Apply watermark
+            if self.watermarker is not None:
+                image = self.watermarker.apply_watermark(image)
+        elif output_type == "pt":
+            latents = 1 / self.vae.config.scaling_factor * latents
+            image = self.vae.decode(latents).sample
+            has_nsfw_concept = None
+        else:
+            image = self.decode_latents(latents)
+            has_nsfw_concept = None
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
        if not return_dict:
-            return (image,)
+            return (image, has_nsfw_concept)

-        return ImagePipelineOutput(images=image)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
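Note: because the return type changes from `ImagePipelineOutput` to `StableDiffusionPipelineOutput`, callers now also receive the (possibly `None`) safety-checker verdict. Continuing the sketch above, with a tiny blank image standing in for a real low-resolution input:

    from PIL import Image

    low_res_img = Image.new("RGB", (128, 128))
    result = pipe(prompt="a white cat", image=low_res_img)
    upscaled = result.images[0]
    print(result.nsfw_content_detected)  # None when no safety checker is configured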
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -15,7 +15,7 @@ from ...models.attention_processor import (
     AttnProcessor,
 )
 from ...models.dual_transformer_2d import DualTransformer2DModel
-from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
 from ...utils import logging
@@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        class_embed_type (`str`, *optional*, defaults to None):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+        addition_embed_type (`str`, *optional*, defaults to None):
+            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+            "text". "text" will use the `TextTimeEmbedding` layer.
        num_class_embeds (`int`, *optional*, defaults to None):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, default to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+        time_embedding_dim (`int`, *optional*, default to `None`):
+            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, default to `None`):
            Optional activation function to use on the time embeddings only one time before they as passed to the rest
            of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -246,12 +251,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
+        addition_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: int = 1.0,
        time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
@@ -261,6 +268,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
+        addition_embed_type_num_heads=64,
    ):
        super().__init__()

@@ -311,7 +319,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

        # time
        if time_embedding_type == "fourier":
-            time_embed_dim = block_out_channels[0] * 2
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
@@ -319,7 +327,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
-            time_embed_dim = block_out_channels[0] * 4
+            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
@@ -345,7 +353,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
-            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
@@ -370,6 +378,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        else:
            self.class_embedding = None

+        if addition_embed_type == "text":
+            if encoder_hid_dim is not None:
+                text_time_embedding_from_dim = encoder_hid_dim
+            else:
+                text_time_embedding_from_dim = cross_attention_dim
+
+            self.add_embedding = TextTimeEmbedding(
+                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+            )
+        elif addition_embed_type is not None:
+            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
        if time_embedding_act_fn is None:
            self.time_embed_act = None
        elif time_embedding_act_fn == "swish":
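Note: the same `addition_embed_type="text"` / `TextTimeEmbedding` option also lands in `UNet2DConditionModel` (see `diffusers/models/unet_2d_condition.py` and `diffusers/models/embeddings.py` in the file list). A toy configuration sketch, with illustrative rather than canonical hyperparameters:

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        addition_embed_type="text",       # pooled text embedding is added to the time embedding
        addition_embed_type_num_heads=4,
    )
    sample = torch.randn(1, 4, 32, 32)
    encoder_hidden_states = torch.randn(1, 77, 32)
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample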
@@ -534,7 +554,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
-            self.conv_act = nn.SiLU()
+
+            if act_fn == "swish":
+                self.conv_act = lambda x: F.silu(x)
+            elif act_fn == "mish":
+                self.conv_act = nn.Mish()
+            elif act_fn == "silu":
+                self.conv_act = nn.SiLU()
+            elif act_fn == "gelu":
+                self.conv_act = nn.GELU()
+            else:
+                raise ValueError(f"Unsupported activation function: {act_fn}")
+
        else:
            self.conv_norm_out = None
            self.conv_act = None
@@ -745,7 +776,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

        t_emb = self.time_proj(timesteps)

-        # timesteps does not contain any weights and will always return f32 tensors
+        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +790,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

+                # `Timesteps` does not contain any weights and will always return f32 tensors
+                # there might be better ways to encapsulate this.
+                class_labels = class_labels.to(dtype=sample.dtype)
+
            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

            if self.config.class_embeddings_concat:
@@ -766,6 +801,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
            else:
                emb = emb + class_emb

+        if self.config.addition_embed_type == "text":
+            aug_emb = self.add_embedding(encoder_hidden_states)
+            emb = emb + aug_emb
+
        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)
