diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -22,7 +22,7 @@ from packaging import version
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

  from ...configuration_utils import FrozenDict
- from ...loaders import TextualInversionLoaderMixin
+ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
  return mask, masked_image


- class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
  r"""
  Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
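
With `LoraLoaderMixin` mixed in, the inpaint pipeline can now load LoRA weights alongside textual-inversion embeddings. A minimal sketch (the base checkpoint id is real; the LoRA file name is illustrative):

    import torch
    from diffusers import StableDiffusionInpaintPipeline

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")      # TextualInversionLoaderMixin
    pipe.load_lora_weights(".", weight_name="my_lora.safetensors")  # LoraLoaderMixin, new in 0.16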

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -22,7 +22,7 @@ from packaging import version
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

  from ...configuration_utils import FrozenDict
- from ...loaders import TextualInversionLoaderMixin
+ from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
  logger = logging.get_logger(__name__)


- def preprocess_image(image):
+ def preprocess_image(image, batch_size):
  w, h = image.size
  w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
  image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
  image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
  image = torch.from_numpy(image)
  return 2.0 * image - 1.0


- def preprocess_mask(mask, scale_factor=8):
+ def preprocess_mask(mask, batch_size, scale_factor=8):
  if not isinstance(mask, torch.FloatTensor):
  mask = mask.convert("L")
  w, h = mask.size
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
  mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
  mask = np.array(mask).astype(np.float32) / 255.0
  mask = np.tile(mask, (4, 1, 1))
- mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
+ mask = np.vstack([mask[None]] * batch_size)
  mask = 1 - mask # repaint white, keep black
  mask = torch.from_numpy(mask)
  return mask
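
Both helpers now return tensors that are already tiled to the prompt batch size, so the pipeline no longer multiplies latents by `batch_size` further down. A quick sketch of the new signatures (import path per file 38 above; sizes illustrative):

    import PIL.Image
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint_legacy import (
        preprocess_image,
        preprocess_mask,
    )

    init_image = PIL.Image.new("RGB", (512, 512))
    mask_image = PIL.Image.new("L", (512, 512))
    image = preprocess_image(init_image, batch_size=2)  # shape [2, 3, 512, 512], values in [-1, 1]
    mask = preprocess_mask(mask_image, batch_size=2)    # shape [2, 4, 64, 64], inverted (repaint white, keep black)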
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
  return mask


- class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
+ class StableDiffusionInpaintPipelineLegacy(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
+ ):
  r"""
  Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
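
The added `FromCkptMixin` lets the legacy inpaint pipeline be built straight from an original Stable Diffusion checkpoint file. A minimal sketch (the checkpoint path and LoRA file name are illustrative):

    from diffusers import StableDiffusionInpaintPipelineLegacy

    pipe = StableDiffusionInpaintPipelineLegacy.from_ckpt("./v1-5-pruned-emaonly.ckpt")  # FromCkptMixin
    pipe.load_lora_weights(".", weight_name="my_inpaint_lora.safetensors")               # LoraLoaderMixin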
@@ -511,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

  return timesteps, num_inference_steps - t_start

- def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+ def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
  image = image.to(device=self.device, dtype=dtype)
  init_latent_dist = self.vae.encode(image).latent_dist
  init_latents = init_latent_dist.sample(generator=generator)
  init_latents = self.vae.config.scaling_factor * init_latents

  # Expand init_latents for batch_size and num_images_per_prompt
- init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
  init_latents_orig = init_latents

  # add noise to latents using the timesteps
@@ -649,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo

  # 4. Preprocess image and mask
  if not isinstance(image, torch.FloatTensor):
- image = preprocess_image(image)
+ image = preprocess_image(image, batch_size)

- mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+ mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)

  # 5. set timesteps
  self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -661,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
  # 6. Prepare latent variables
  # encode the init image into latents and scale the latents
  latents, init_latents_orig, noise = self.prepare_latents(
- image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+ image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
  )

  # 7. Prepare mask latent
  mask = mask_image.to(device=self.device, dtype=latents.dtype)
- mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+ mask = torch.cat([mask] * num_images_per_prompt)

  # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -20,7 +20,7 @@ import PIL
  import torch
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

- from ...loaders import TextualInversionLoaderMixin
+ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
@@ -61,13 +61,20 @@ def preprocess(image):
  return image


- class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
  r"""
  Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -36,6 +36,7 @@ from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
  from ...utils import (
  PIL_INTERPOLATION,
  BaseOutput,
+ deprecate,
  is_accelerate_available,
  is_accelerate_version,
  logging,
@@ -721,23 +722,31 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
  )

  if isinstance(generator, list):
- init_latents = [
- self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
- ]
- init_latents = torch.cat(init_latents, dim=0)
+ latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+ latents = torch.cat(latents, dim=0)
  else:
- init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
- init_latents = self.vae.config.scaling_factor * init_latents
-
- if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
- )
+ latents = self.vae.encode(image).latent_dist.sample(generator)
+
+ latents = self.vae.config.scaling_factor * latents
+
+ if batch_size != latents.shape[0]:
+ if batch_size % latents.shape[0] == 0:
+ # expand image_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many initial images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_latents_per_image = batch_size // latents.shape[0]
+ latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+ else:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+ )
  else:
- init_latents = torch.cat([init_latents], dim=0)
-
- latents = init_latents
+ latents = torch.cat([latents], dim=0)

  return latents
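
`prepare_image_latents` now tolerates fewer init images than prompts: when the prompt count is a whole multiple of the image count, the latents are tiled and a deprecation warning is emitted; otherwise it still raises. The rule in isolation (the helper name is ours, not part of the library):

    import torch

    def expand_image_latents(latents: torch.Tensor, batch_size: int) -> torch.Tensor:
        # batch_size = number of text prompts, latents.shape[0] = number of encoded init images
        if batch_size == latents.shape[0]:
            return latents
        if batch_size % latents.shape[0] == 0:
            # deprecated path in 0.16: duplicate image latents to match the prompt count
            return torch.cat([latents] * (batch_size // latents.shape[0]), dim=0)
        raise ValueError(
            f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
        )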
@@ -759,23 +768,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
  )

  def auto_corr_loss(self, hidden_states, generator=None):
- batch_size, channel, height, width = hidden_states.shape
- if batch_size > 1:
- raise ValueError("Only batch_size 1 is supported for now")
-
- hidden_states = hidden_states.squeeze(0)
- # hidden_states must be shape [C,H,W] now
  reg_loss = 0.0
  for i in range(hidden_states.shape[0]):
- noise = hidden_states[i][None, None, :, :]
- while True:
- roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
- reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
- reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
-
- if noise.shape[2] <= 8:
- break
- noise = F.avg_pool2d(noise, kernel_size=2)
+ for j in range(hidden_states.shape[1]):
+ noise = hidden_states[i : i + 1, j : j + 1, :, :]
+ while True:
+ roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+ reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+ if noise.shape[2] <= 8:
+ break
+ noise = F.avg_pool2d(noise, kernel_size=2)
  return reg_loss

  def kl_divergence(self, hidden_states):
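
For reference, the reworked regularizer as a standalone function: it now walks every (batch, channel) slice instead of rejecting batches larger than one, accumulating squared auto-correlations over a pyramid of average-pooled scales (a sketch lifted from the diff above):

    import torch
    import torch.nn.functional as F

    def auto_corr_loss(hidden_states: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor:
        reg_loss = 0.0
        for i in range(hidden_states.shape[0]):      # batch
            for j in range(hidden_states.shape[1]):  # channel
                noise = hidden_states[i : i + 1, j : j + 1, :, :]
                while True:
                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                    if noise.shape[2] <= 8:
                        break
                    noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss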

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -13,18 +13,20 @@
  # limitations under the License.

  import inspect
- from typing import Callable, List, Optional, Union
+ from typing import Any, Callable, List, Optional, Union

  import numpy as np
  import PIL
  import torch
- from transformers import CLIPTextModel, CLIPTokenizer
+ import torch.nn.functional as F
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

  from ...loaders import TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
- from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
- from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+ from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
+ from ..pipeline_utils import DiffusionPipeline
+ from . import StableDiffusionPipelineOutput


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -76,6 +78,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
  """
+ _optional_components = ["watermarker", "safety_checker", "feature_extractor"]

  def __init__(
  self,
@@ -85,12 +88,16 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
  unet: UNet2DConditionModel,
  low_res_scheduler: DDPMScheduler,
  scheduler: KarrasDiffusionSchedulers,
+ safety_checker: Optional[Any] = None,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ watermarker: Optional[Any] = None,
  max_noise_level: int = 350,
  ):
  super().__init__()

- if hasattr(vae, "config"):
- # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+ if hasattr(
+ vae, "config"
+ ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
  is_vae_scaling_factor_set_to_0_08333 = (
  hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
  )
@@ -113,6 +120,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
  unet=unet,
  low_res_scheduler=low_res_scheduler,
  scheduler=scheduler,
+ safety_checker=safety_checker,
+ watermarker=watermarker,
+ feature_extractor=feature_extractor,
  )
  self.register_to_config(max_noise_level=max_noise_level)
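
`safety_checker`, `feature_extractor` and `watermarker` are registered as optional components, so existing upscaler checkpoints that do not ship them keep loading unchanged. Sketch:

    from diffusers import StableDiffusionUpscalePipeline

    pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler")
    print(pipe.safety_checker, pipe.watermarker)  # None unless the checkpoint provides them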
@@ -129,10 +139,36 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi

  device = torch.device(f"cuda:{gpu_id}")

- for cpu_offloaded_model in [self.unet, self.text_encoder]:
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
  if cpu_offloaded_model is not None:
  cpu_offload(cpu_offloaded_model, device)

+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
  @property
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
  def _execution_device(self):
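
Usage sketch for the new model-level offload on the upscale pipeline (requires accelerate >= 0.17.0):

    import torch
    from diffusers import StableDiffusionUpscalePipeline

    pipe = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()  # instead of pipe.to("cuda"); keeps only the active sub-model on the GPU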
@@ -152,6 +188,23 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
  return torch.device(module._hf_hook.execution_device)
  return self.device

+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, nsfw_detected, watermark_detected = self.safety_checker(
+ images=image,
+ clip_input=safety_checker_input.pixel_values.to(dtype=dtype),
+ )
+ else:
+ nsfw_detected = None
+ watermark_detected = None
+
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
+
+ return image, nsfw_detected, watermark_detected
+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
  def _encode_prompt(
  self,
@@ -645,13 +698,43 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMi
  # 10. Post-processing
  # make sure the VAE is in float32 mode, as it overflows in float16
  self.vae.to(dtype=torch.float32)
- image = self.decode_latents(latents.float())
+
+ # TODO(Patrick, William) - clean up when attention is refactored
+ use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+ use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if not use_torch_2_0_attn and not use_xformers:
+ self.vae.post_quant_conv.to(latents.dtype)
+ self.vae.decoder.conv_in.to(latents.dtype)
+ self.vae.decoder.mid_block.to(latents.dtype)
+ else:
+ latents = latents.float()

  # 11. Convert to PIL
  if output_type == "pil":
+ image = self.decode_latents(latents)
+
+ image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype)
+
  image = self.numpy_to_pil(image)

+ # 11. Apply watermark
+ if self.watermarker is not None:
+ image = self.watermarker.apply_watermark(image)
+ elif output_type == "pt":
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ has_nsfw_concept = None
+ else:
+ image = self.decode_latents(latents)
+ has_nsfw_concept = None
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
  if not return_dict:
- return (image,)
+ return (image, has_nsfw_concept)

- return ImagePipelineOutput(images=image)
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
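
Because the call now returns a `StableDiffusionPipelineOutput`, callers get an `nsfw_content_detected` field next to the images (previously an `ImagePipelineOutput` carrying images only). Sketch, reusing a pipeline loaded as above with an illustrative low-resolution input image:

    result = pipe(prompt="a white cat", image=low_res_img, num_inference_steps=25)
    upscaled = result.images[0]
    if result.nsfw_content_detected and any(result.nsfw_content_detected):
        print("safety checker flagged the output")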

diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -15,7 +15,7 @@ from ...models.attention_processor import (
  AttnProcessor,
  )
  from ...models.dual_transformer_2d import DualTransformer2DModel
- from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+ from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps
  from ...models.transformer_2d import Transformer2DModel
  from ...models.unet_2d_condition import UNet2DConditionOutput
  from ...utils import logging
@@ -183,11 +183,16 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  class_embed_type (`str`, *optional*, defaults to None):
  The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
  `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to None):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
  num_class_embeds (`int`, *optional*, defaults to None):
  Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
  class conditioning with `class_embed_type` equal to `None`.
  time_embedding_type (`str`, *optional*, default to `positional`):
  The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, default to `None`):
+ An optional override for the dimension of the projected time embedding.
  time_embedding_act_fn (`str`, *optional*, default to `None`):
  Optional activation function to use on the time embeddings only one time before they as passed to the rest
  of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
@@ -246,12 +251,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  dual_cross_attention: bool = False,
  use_linear_projection: bool = False,
  class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
  num_class_embeds: Optional[int] = None,
  upcast_attention: bool = False,
  resnet_time_scale_shift: str = "default",
  resnet_skip_time_act: bool = False,
  resnet_out_scale_factor: int = 1.0,
  time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
  time_embedding_act_fn: Optional[str] = None,
  timestep_post_act: Optional[str] = None,
  time_cond_proj_dim: Optional[int] = None,
@@ -261,6 +268,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  class_embeddings_concat: bool = False,
  mid_block_only_cross_attention: Optional[bool] = None,
  cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads=64,
  ):
  super().__init__()

@@ -311,7 +319,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  # time
  if time_embedding_type == "fourier":
- time_embed_dim = block_out_channels[0] * 2
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
  if time_embed_dim % 2 != 0:
  raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
  self.time_proj = GaussianFourierProjection(
@@ -319,7 +327,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  )
  timestep_input_dim = time_embed_dim
  elif time_embedding_type == "positional":
- time_embed_dim = block_out_channels[0] * 4
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

  self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
  timestep_input_dim = block_out_channels[0]
@@ -345,7 +353,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  if class_embed_type is None and num_class_embeds is not None:
  self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
  elif class_embed_type == "timestep":
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
  elif class_embed_type == "identity":
  self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
  elif class_embed_type == "projection":
@@ -370,6 +378,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  else:
  self.class_embedding = None

+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.")
+
  if time_embedding_act_fn is None:
  self.time_embed_act = None
  elif time_embedding_act_fn == "swish":
@@ -534,7 +554,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  self.conv_norm_out = nn.GroupNorm(
  num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
  )
- self.conv_act = nn.SiLU()
+
+ if act_fn == "swish":
+ self.conv_act = lambda x: F.silu(x)
+ elif act_fn == "mish":
+ self.conv_act = nn.Mish()
+ elif act_fn == "silu":
+ self.conv_act = nn.SiLU()
+ elif act_fn == "gelu":
+ self.conv_act = nn.GELU()
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
  else:
  self.conv_norm_out = None
  self.conv_act = None
@@ -745,7 +776,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

  t_emb = self.time_proj(timesteps)

- # timesteps does not contain any weights and will always return f32 tensors
+ # `Timesteps` does not contain any weights and will always return f32 tensors
  # but time_embedding might actually be running in fp16. so we need to cast here.
  # there might be better ways to encapsulate this.
  t_emb = t_emb.to(dtype=self.dtype)
@@ -759,6 +790,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  if self.config.class_embed_type == "timestep":
  class_labels = self.time_proj(class_labels)

+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
  class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)

  if self.config.class_embeddings_concat:
@@ -766,6 +801,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  else:
  emb = emb + class_emb

+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ emb = emb + aug_emb
+
  if self.time_embed_act is not None:
  emb = self.time_embed_act(emb)
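
The same `addition_embed_type`, `addition_embed_type_num_heads` and `time_embedding_dim` arguments land in the public `UNet2DConditionModel` (file 11 above), from which this flat variant is copied. A small illustrative configuration, assuming the arguments behave as shown in the diff:

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        addition_embed_type="text",      # adds a TextTimeEmbedding of the encoder hidden states to the time embedding
        addition_embed_type_num_heads=2,
        time_embedding_dim=128,          # overrides the default block_out_channels[0] * 4
    )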