diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
@@ -57,9 +65,21 @@ EXAMPLE_DOC_STRING = """


 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
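
Only the docstring changes in this hunk; the math below it is untouched. For reference, a minimal standalone sketch of the full computation (the two std statements are the context lines visible above; the rest follows the same Section 3.4 recipe):

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample std over every non-batch dimension.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided prediction to the std of the text-only prediction,
    # then blend back toward the unrescaled result by `guidance_rescale`.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg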
@@ -78,7 +98,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -137,7 +157,7 @@ class StableDiffusionPipeline(
     IPAdapterMixin,
     FromSingleFileMixin,
 ):
-    r"""
+    """
    Pipeline for text-to-image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
@@ -1036,6 +1056,9 @@ class StableDiffusionPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
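
Together with the import hunks above, this adds TPU support to the pipeline: when torch_xla is importable, xm.mark_step() runs once per denoising step so XLA compiles and dispatches the graph incrementally. A hedged usage sketch, assuming a TPU host with torch_xla installed (the model id is illustrative):

import torch
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionPipeline

device = xm.xla_device()  # first available TPU core
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
).to(device)

# No manual stepping needed: the denoising loop above marks a step itself.
image = pipe("a photo of an astronaut riding a horse").images[0]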
@@ -1049,7 +1072,6 @@ class StableDiffusionPipeline(
             do_denormalize = [True] * image.shape[0]
         else:
             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         # Offload all models
@@ -119,7 +119,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -60,7 +60,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
     warnings.warn(
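
The helper copied in above gives the latent upscaler the same three-way latent retrieval the img2img pipeline uses: sample the posterior, take its mode, or pass plain latents through. A small usage sketch (the VAE checkpoint id is illustrative; the random tensor stands in for a preprocessed image in [-1, 1]):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
image = torch.randn(1, 3, 64, 64)  # stand-in for a preprocessed image

enc = vae.encode(image)
latents = retrieve_latents(enc, generator=torch.Generator().manual_seed(0))  # samples the posterior
latents_det = retrieve_latents(enc, sample_mode="argmax")  # deterministic: takes the mode instead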
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")

-    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            **kwargs,
+        )
+
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.

@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             negative_prompt (`str` or `List[str]`):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
         """
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            truncation=True,
-            return_length=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        text_encoder_out = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=True,
-        )
-        text_embeddings = text_encoder_out.hidden_states[-1]
-        text_pooler_out = text_encoder_out.pooler_output
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-            max_length = text_input_ids.shape[-1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
+        if prompt_embeds is None or pooled_prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
                 padding="max_length",
-                max_length=max_length,
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_length=True,
                 return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids

-            uncond_encoder_out = self.text_encoder(
-                uncond_input.input_ids.to(device),
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            text_encoder_out = self.text_encoder(
+                text_input_ids.to(device),
                 output_hidden_states=True,
             )
+            prompt_embeds = text_encoder_out.hidden_states[-1]
+            pooled_prompt_embeds = text_encoder_out.pooler_output

-            uncond_embeddings = uncond_encoder_out.hidden_states[-1]
-            uncond_pooler_out = uncond_encoder_out.pooler_output
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
+                uncond_tokens: List[str]
+                if negative_prompt is None:
+                    uncond_tokens = [""] * batch_size
+                elif type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif isinstance(negative_prompt, str):
+                    uncond_tokens = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                else:
+                    uncond_tokens = negative_prompt
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_length=True,
+                    return_tensors="pt",
+                )
+
+                uncond_encoder_out = self.text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )

-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-        text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
+                negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
+                negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output

-        return text_embeddings, text_pooler_out
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
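
With encode_prompt exposed (and _encode_prompt reduced to a deprecation shim over it), the upscaler now follows the same precomputed-embedding convention as the other Stable Diffusion pipelines. A hedged sketch of encoding once and reusing the four tensors across calls (the checkpoint id is the published upscaler; low_res_latents is an illustrative stand-in for latents produced by a base pipeline):

import torch
from diffusers import StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# Encode once...
prompt_embeds, neg_embeds, pooled, neg_pooled = pipe.encode_prompt(
    "a detailed photo of a castle", "cuda", do_classifier_free_guidance=True
)

# ...then reuse the tensors instead of re-passing `prompt` on every call.
upscaled = pipe(
    image=low_res_latents,  # latents from a base pipeline (illustrative)
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=neg_embeds,
    pooled_prompt_embeds=pooled,
    negative_pooled_prompt_embeds=neg_pooled,
).images[0]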
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

-    def check_inputs(self, prompt, image, callback_steps):
-        if not isinstance(prompt, str) and not isinstance(prompt, list):
+    def check_inputs(
+        self,
+        prompt,
+        image,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         if (
             not isinstance(image, torch.Tensor)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, PIL.Image.Image)
             and not isinstance(image, list)
         ):
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix

         # verify batch size of prompt and image are same if image is a list or tensor
         if isinstance(image, (list, torch.Tensor)):
-            if isinstance(prompt, str):
-                batch_size = 1
+            if prompt is not None:
+                if isinstance(prompt, str):
+                    batch_size = 1
+                else:
+                    batch_size = len(prompt)
             else:
-                batch_size = len(prompt)
+                batch_size = prompt_embeds.shape[0]
+
             if isinstance(image, list):
                 image_batch_size = len(image)
             else:
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         """

         # 1. Check inputs
-        self.check_inputs(prompt, image, callback_steps)
+        self.check_inputs(
+            prompt,
+            image,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )

         # 2. Define call parameters
-        batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        if prompt is not None:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             prompt = [""] * batch_size

         # 3. Encode input prompt
-        text_embeddings, text_pooler_out = self._encode_prompt(
-            prompt, device, do_classifier_free_guidance, negative_prompt
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            device,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
         )

+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
         # 4. Preprocess image
         image = self.image_processor.preprocess(image)
-        image = image.to(dtype=text_embeddings.dtype, device=device)
+        image = image.to(dtype=prompt_embeds.dtype, device=device)
         if image.shape[1] == 3:
             # encode image if not in latent-space yet
-            image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
+            image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         inv_noise_level = (noise_level**2 + 1) ** (-0.5)

         image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
-        image_cond = image_cond.to(text_embeddings.dtype)
+        image_cond = image_cond.to(prompt_embeds.dtype)

         noise_level_embed = torch.cat(
             [
-                torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
-                torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
+                torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
+                torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
             ],
             dim=1,
         )

-        timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
+        timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)

         # 6. Prepare latent variables
         height, width = image.shape[2:]
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             num_channels_latents,
             height * 2,  # 2x upscale
             width * 2,
-            text_embeddings.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
454
614
  noise_pred = self.unet(
455
615
  scaled_model_input,
456
616
  timestep,
457
- encoder_hidden_states=text_embeddings,
617
+ encoder_hidden_states=prompt_embeds,
458
618
  timestep_cond=timestep_condition,
459
619
  ).sample
460
620
 
@@ -77,7 +77,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -203,6 +203,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             if hasattr(self, "transformer") and self.transformer is not None
             else 128
         )
+        self.patch_size = (
+            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+        )

     def _get_t5_prompt_embeds(
         self,
@@ -525,8 +528,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if (
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+                f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
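
The divisor is now derived from the model config instead of a hardcoded 8. With SD3's defaults from the hunks above (vae_scale_factor = 8 from the VAE, patch_size = 2 from the transformer), sizes must be multiples of 16; a quick sanity check of the arithmetic the error message performs:

vae_scale_factor, patch_size = 8, 2  # SD3 defaults, per the hunks above
divisor = vae_scale_factor * patch_size  # 16
height = 1000  # not a multiple of 16
suggested = height - height % divisor
assert suggested == 992  # the size the error message would suggest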
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import PIL.Image
 import torch
@@ -25,7 +25,7 @@ from transformers import (
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -98,7 +98,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -149,7 +149,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
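
Picking up SD3LoraLoaderMixin and FromSingleFileMixin brings the img2img pipeline in line with the text-to-image SD3 pipeline: it gains load_lora_weights and from_single_file. A hedged sketch (the repo id is SD3 medium; the file paths are placeholders):

import torch
from diffusers import StableDiffusion3Img2ImgPipeline

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/sd3_lora.safetensors")  # placeholder path, via SD3LoraLoaderMixin

# Or load a monolithic checkpoint directly, via FromSingleFileMixin:
# pipe = StableDiffusion3Img2ImgPipeline.from_single_file("path/to/sd3.safetensors")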
@@ -680,6 +680,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -723,6 +727,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -835,6 +844,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -847,6 +857,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

         device = self._execution_device

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -868,6 +882,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )

         if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]