diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
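Among the larger additions in the list above is the new diffusers/quantizers package (auto.py, base.py, quantization_config.py, plus bitsandbytes/, gguf/ and torchao/ backends), which lets individual models be loaded in quantized form. A minimal sketch of the intended usage, assuming the bitsandbytes extra is installed; the checkpoint ID below is illustrative and not taken from this diff:

    import torch
    from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

    # 4-bit quantization config exposed by the new diffusers.quantizers module
    quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large",  # illustrative checkpoint
        subfolder="transformer",
        quantization_config=quant_config,
        torch_dtype=torch.bfloat16,
    )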
diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

@@ -1,4 +1,4 @@
- # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -17,14 +17,16 @@ from typing import Any, Callable, Dict, List, Optional, Union

  import torch
  from transformers import (
+ BaseImageProcessor,
  CLIPTextModelWithProjection,
  CLIPTokenizer,
+ PreTrainedModel,
  T5EncoderModel,
  T5TokenizerFast,
  )

- from ...image_processor import VaeImageProcessor
- from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
+ from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
  from ...models.autoencoders import AutoencoderKL
  from ...models.transformers import SD3Transformer2DModel
  from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -68,6 +70,20 @@ EXAMPLE_DOC_STRING = """
  """


+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+ def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+ ):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
  def retrieve_timesteps(
  scheduler,
@@ -77,7 +93,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

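The calculate_shift helper added above is a copy of the Flux version: it linearly maps the latent token count onto a timestep-shift value mu between base_shift and max_shift. A small worked example under the hunk's default constants; the 1024x1024 resolution and the SD3 factors (vae_scale_factor=8, patch_size=2) are illustrative:

    # defaults from the added calculate_shift signature
    base_seq_len, max_seq_len = 256, 4096
    base_shift, max_shift = 0.5, 1.16

    # a 1024x1024 image with vae_scale_factor=8 and patch_size=2 gives a 64x64 token grid
    image_seq_len = (1024 // 8 // 2) * (1024 // 8 // 2)  # 4096

    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    print(mu)  # 1.16 -- the maximum shift, since image_seq_len equals max_seq_len here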
@@ -128,7 +144,7 @@ def retrieve_timesteps(
  return timesteps, num_inference_steps


- class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
  r"""
  Args:
  transformer ([`SD3Transformer2DModel`]):
@@ -160,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  tokenizer_3 (`T5TokenizerFast`):
  Tokenizer of class
  [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ image_encoder (`PreTrainedModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`BaseImageProcessor`, *optional*):
+ Image processor for IP Adapter.
  """

- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
  _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

  def __init__(
@@ -177,6 +197,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  tokenizer_2: CLIPTokenizer,
  text_encoder_3: T5EncoderModel,
  tokenizer_3: T5TokenizerFast,
+ image_encoder: PreTrainedModel = None,
+ feature_extractor: BaseImageProcessor = None,
  ):
  super().__init__()

@@ -190,6 +212,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  tokenizer_3=tokenizer_3,
  transformer=transformer,
  scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
  )
  self.vae_scale_factor = (
  2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -203,6 +227,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  if hasattr(self, "transformer") and self.transformer is not None
  else 128
  )
+ self.patch_size = (
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+ )

  def _get_t5_prompt_embeds(
  self,
@@ -525,8 +552,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  callback_on_step_end_tensor_inputs=None,
  max_sequence_length=None,
  ):
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+ )

  if callback_on_step_end_tensor_inputs is not None and not all(
  k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
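The hunk above generalizes the resolution check from a hard-coded 8 to vae_scale_factor * patch_size. For the usual SD3 configuration (vae_scale_factor=8, patch_size=2, values assumed here) that multiple is 16; a quick check of what the new error message would suggest for an off-grid size:

    vae_scale_factor, patch_size = 8, 2        # typical SD3 values
    multiple = vae_scale_factor * patch_size   # 16
    height = 1000
    print(height % multiple)                   # 8  -> would trigger the ValueError above
    print(height - height % multiple)          # 992 -> the height the message recommends instead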
@@ -633,6 +666,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  def guidance_scale(self):
  return self._guidance_scale

+ @property
+ def skip_guidance_layers(self):
+ return self._skip_guidance_layers
+
  @property
  def clip_skip(self):
  return self._clip_skip
@@ -656,6 +693,83 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  def interrupt(self):
  return self._interrupt

+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device: (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device: (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
  @torch.no_grad()
  @replace_example_docstring(EXAMPLE_DOC_STRING)
  def __call__(
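With encode_image and prepare_ip_adapter_image_embeds in place, the pipeline can consume an image prompt through the new SD3IPAdapterMixin. A hedged sketch of the intended call pattern; the IP-Adapter checkpoint ID, loader arguments, and URLs below are assumptions, not taken from this diff:

    import torch
    from diffusers import StableDiffusion3Pipeline
    from diffusers.utils import load_image

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
    ).to("cuda")
    # assumed IP-Adapter checkpoint; loading also populates image_encoder and feature_extractor
    pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")

    image = pipe(
        prompt="a cat wearing a spacesuit",
        ip_adapter_image=load_image("https://example.com/reference.png"),  # placeholder URL
        num_inference_steps=28,
        guidance_scale=7.0,
    ).images[0]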
@@ -666,7 +780,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  height: Optional[int] = None,
  width: Optional[int] = None,
  num_inference_steps: int = 28,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
  guidance_scale: float = 7.0,
  negative_prompt: Optional[Union[str, List[str]]] = None,
  negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -678,6 +792,8 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -685,6 +801,11 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
  max_sequence_length: int = 256,
+ skip_guidance_layers: List[int] = None,
+ skip_layer_guidance_scale: float = 2.8,
+ skip_layer_guidance_stop: float = 0.2,
+ skip_layer_guidance_start: float = 0.01,
+ mu: Optional[float] = None,
  ):
  r"""
  Function invoked when calling the pipeline for generation.
@@ -706,10 +827,10 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  num_inference_steps (`int`, *optional*, defaults to 50):
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
  guidance_scale (`float`, *optional*, defaults to 7.0):
  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -749,12 +870,17 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
  input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
  output_type (`str`, *optional*, defaults to `"pil"`):
  The output format of the generate image. Choose between
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
  joint_attention_kwargs (`dict`, *optional*):
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
  `self.processor` in
@@ -769,6 +895,23 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeline class.
  max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ skip_guidance_layers (`List[int]`, *optional*):
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
+ Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
+ skip_layer_guidance_scale (`int`, *optional*): The scale of the guidance for the layers specified in
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
+ with a scale of `1`.
+ skip_layer_guidance_stop (`int`, *optional*): The step at which the guidance for the layers specified in
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.2.
+ skip_layer_guidance_start (`int`, *optional*): The step at which the guidance for the layers specified in
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
+ StabiltyAI for Stable Diffusion 3.5 Medium is 0.01.
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.

  Examples:

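The new skip_* arguments implement skip-layer guidance: during the middle portion of sampling an extra transformer pass is run with skip_guidance_layers disabled, and its prediction is blended back in at skip_layer_guidance_scale (see the denoising-loop hunk further down). A short sketch using the values the docstring above recommends for Stable Diffusion 3.5 Medium; the model ID is illustrative:

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
    ).to("cuda")

    image = pipe(
        "a photo of a red panda reading a newspaper",
        num_inference_steps=28,
        guidance_scale=7.0,
        skip_guidance_layers=[7, 8, 9],     # recommended for SD3.5 Medium per the docstring above
        skip_layer_guidance_scale=2.8,
        skip_layer_guidance_start=0.01,
        skip_layer_guidance_stop=0.2,
    ).images[0]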
@@ -800,6 +943,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  )

  self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
  self._clip_skip = clip_skip
  self._joint_attention_kwargs = joint_attention_kwargs
  self._interrupt = False
@@ -842,15 +986,13 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  )

  if self.do_classifier_free_guidance:
+ if skip_guidance_layers is not None:
+ original_prompt_embeds = prompt_embeds
+ original_pooled_prompt_embeds = pooled_prompt_embeds
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

- # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # 5. Prepare latent variables
+ # 4. Prepare latent variables
  num_channels_latents = self.transformer.config.in_channels
  latents = self.prepare_latents(
  batch_size * num_images_per_prompt,
@@ -863,7 +1005,49 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  latents,
  )

- # 6. Denoising loop
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ _, _, height, width = latents.shape
+ image_seq_len = (height // self.transformer.config.patch_size) * (
+ width // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 7. Denoising loop
  with self.progress_bar(total=num_inference_steps) as progress_bar:
  for i, t in enumerate(timesteps):
  if self.interrupt:
@@ -887,6 +1071,27 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
  if self.do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
  noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ should_skip_layers = (
+ True
+ if i > num_inference_steps * skip_layer_guidance_start
+ and i < num_inference_steps * skip_layer_guidance_stop
+ else False
+ )
+ if skip_guidance_layers is not None and should_skip_layers:
+ timestep = t.expand(latents.shape[0])
+ latent_model_input = latents
+ noise_pred_skip_layers = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=original_prompt_embeds,
+ pooled_projections=original_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ skip_layers=skip_guidance_layers,
+ )[0]
+ noise_pred = (
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+ )

  # compute the previous noisy sample x_t -> x_t-1
  latents_dtype = latents.dtype
diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

@@ -13,7 +13,7 @@
  # limitations under the License.

  import inspect
- from typing import Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Union

  import PIL.Image
  import torch
@@ -25,7 +25,7 @@ from transformers import (
  )

  from ...image_processor import PipelineImageInput, VaeImageProcessor
- from ...loaders import SD3LoraLoaderMixin
+ from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
  from ...models.autoencoders import AutoencoderKL
  from ...models.transformers import SD3Transformer2DModel
  from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -75,6 +75,20 @@ EXAMPLE_DOC_STRING = """
  """


+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+ def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+ ):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
  def retrieve_latents(
  encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -98,7 +112,7 @@ def retrieve_timesteps(
  sigmas: Optional[List[float]] = None,
  **kwargs,
  ):
- """
+ r"""
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -149,7 +163,7 @@ def retrieve_timesteps(
  return timesteps, num_inference_steps


- class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
  r"""
  Args:
  transformer ([`SD3Transformer2DModel`]):
@@ -218,6 +232,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  )
  self.tokenizer_max_length = self.tokenizer.model_max_length
  self.default_sample_size = self.transformer.config.sample_size
+ self.patch_size = (
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+ )

  # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
  def _get_t5_prompt_embeds(
@@ -531,6 +548,8 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  prompt,
  prompt_2,
  prompt_3,
+ height,
+ width,
  strength,
  negative_prompt=None,
  negative_prompt_2=None,
@@ -542,6 +561,15 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  callback_on_step_end_tensor_inputs=None,
  max_sequence_length=None,
  ):
+ if (
+ height % (self.vae_scale_factor * self.patch_size) != 0
+ or width % (self.vae_scale_factor * self.patch_size) != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+ )
+
  if strength < 0 or strength > 1:
  raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

@@ -680,6 +708,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  def guidance_scale(self):
  return self._guidance_scale

+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
  @property
  def clip_skip(self):
  return self._clip_skip
@@ -706,10 +738,12 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  prompt: Union[str, List[str]] = None,
  prompt_2: Optional[Union[str, List[str]]] = None,
  prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
  image: PipelineImageInput = None,
  strength: float = 0.6,
  num_inference_steps: int = 50,
- timesteps: List[int] = None,
+ sigmas: Optional[List[float]] = None,
  guidance_scale: float = 7.0,
  negative_prompt: Optional[Union[str, List[str]]] = None,
  negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -723,10 +757,12 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
  output_type: Optional[str] = "pil",
  return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
  clip_skip: Optional[int] = None,
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
  max_sequence_length: int = 256,
+ mu: Optional[float] = None,
  ):
  r"""
  Function invoked when calling the pipeline for generation.
@@ -748,10 +784,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  num_inference_steps (`int`, *optional*, defaults to 50):
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
  expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
  guidance_scale (`float`, *optional*, defaults to 7.0):
  Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
  `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -795,8 +831,12 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  The output format of the generate image. Choose between
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
  return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+ a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
  callback_on_step_end (`Callable`, *optional*):
  A function that calls at the end of each denoising steps during the inference. The function is called
  with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -807,6 +847,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
  `._callback_tensor_inputs` attribute of your pipeline class.
  max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.

  Examples:

@@ -815,12 +856,16 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
  `tuple`. When returning a tuple, the first element is a list with the generated images.
  """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor

  # 1. Check inputs. Raise error if not correct
  self.check_inputs(
  prompt,
  prompt_2,
  prompt_3,
+ height,
+ width,
  strength,
  negative_prompt=negative_prompt,
  negative_prompt_2=negative_prompt_2,
@@ -835,6 +880,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

  self._guidance_scale = guidance_scale
  self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
  self._interrupt = False

  # 2. Define call parameters
@@ -847,6 +893,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

  device = self._execution_device

+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
  (
  prompt_embeds,
  negative_prompt_embeds,
@@ -868,6 +918,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  clip_skip=self.clip_skip,
  num_images_per_prompt=num_images_per_prompt,
  max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
  )

  if self.do_classifier_free_guidance:
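Because the img2img pipeline now inherits SD3LoraLoaderMixin and reads a lora_scale out of joint_attention_kwargs (see the hunks above), a loaded LoRA can be weighted per call. A sketch under assumed repository names and URLs:

    import torch
    from diffusers import StableDiffusion3Img2ImgPipeline
    from diffusers.utils import load_image

    pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights("some-user/sd3-style-lora")  # hypothetical LoRA repository

    image = pipe(
        prompt="same scene, watercolor style",
        image=load_image("https://example.com/input.png"),  # placeholder URL
        strength=0.6,
        joint_attention_kwargs={"scale": 0.8},  # forwarded as lora_scale to encode_prompt per the hunks above
    ).images[0]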
@@ -875,10 +926,27 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

  # 3. Preprocess image
- image = self.image_processor.preprocess(image)
+ image = self.image_processor.preprocess(image, height=height, width=width)

  # 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+ int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+ )
  timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
  latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

@@ -912,6 +980,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
  timestep=timestep,
  encoder_hidden_states=prompt_embeds,
  pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
  return_dict=False,
  )[0]