diffusers-0.30.3-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux.py

@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -86,7 +93,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -137,7 +144,13 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+class FluxPipeline(
+    DiffusionPipeline,
+    FluxLoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+    FluxIPAdapterMixin,
+):
     r"""
     The Flux pipeline for text-to-image generation.
 
@@ -164,8 +177,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
@@ -177,6 +190,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         text_encoder_2: T5EncoderModel,
         tokenizer_2: T5TokenizerFast,
         transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()
 
@@ -188,15 +203,19 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             tokenizer_2=tokenizer_2,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
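The changed defaults trace back to simple arithmetic: the Flux VAE config has four block_out_channels entries, so the corrected exponent gives the true 8x spatial compression, and the extra factor of two accounts for the 2x2 patch packing. A standalone sketch of that arithmetic (the config values are assumed for illustration, not taken from this diff):

block_out_channels = [128, 256, 512, 512]  # assumed Flux VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 8; 0.30.x computed 2 ** 4 = 16

patch_size = 2  # Flux packs latents into 2x2 patches
multiple = vae_scale_factor * patch_size  # pixel dims must be divisible by 16

height = width = 1000
print((height // multiple) * multiple)  # 992: the size check_inputs now warns about and resizes to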
@@ -212,6 +231,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
         text_inputs = self.tokenizer_2(
             prompt,
             padding="max_length",
@@ -255,6 +277,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
         text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
@@ -331,10 +356,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
@@ -364,24 +385,71 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
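These two helpers feed the new FluxIPAdapterMixin path. A hedged usage sketch of how they are typically exercised (the repo ids, file name, and load_ip_adapter keywords follow the 0.32.0 docs as best understood here and are assumptions, not part of this diff):

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# load_ip_adapter (from FluxIPAdapterMixin) populates the image_encoder and
# feature_extractor components this diff adds as _optional_components.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",  # assumed adapter repo id
    weight_name="ip_adapter.safetensors",  # assumed file name
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)

style = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a dog, in the style of the reference image",
    ip_adapter_image=style,  # routed through prepare_ip_adapter_image_embeds above
    num_inference_steps=28,
).images[0]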
@@ -409,25 +477,47 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
-        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         latent_image_ids = latent_image_ids.reshape(
-            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
         return latent_image_ids.to(device=device, dtype=dtype)
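A standalone trace of the reworked ID layout, using a toy 4x4 patch grid for illustration: the positional IDs are now built once per grid rather than repeated per batch element (batching is handled downstream).

import torch

height, width = 4, 4  # packed-patch grid, i.e. latent dims // 2
ids = torch.zeros(height, width, 3)
ids[..., 1] += torch.arange(height)[:, None]  # patch row index
ids[..., 2] += torch.arange(width)[None, :]   # patch column index
ids = ids.reshape(height * width, 3)
print(ids.shape)  # torch.Size([16, 3]); 0.30.x returned [batch_size, 16, 3]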
@@ -444,16 +534,47 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def prepare_latents(
         self,
         batch_size,
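To make the new unpacking arithmetic concrete, this standalone sketch traces shapes for a 1024x1024 image with typical Flux values (vae_scale_factor 8, 16 latent channels; the numbers are illustrative, not taken from this diff):

import torch

batch_size, height, width, vae_scale_factor = 1, 1024, 1024, 8

# Packed latents: one 64-dim token per 2x2 patch of the 16-channel latent grid.
num_patches = (height // 16) * (width // 16)  # 4096 tokens
latents = torch.randn(batch_size, num_patches, 64)

# New arithmetic: round pixel dims down to a multiple of 16, then to latent dims.
lat_h = 2 * (int(height) // (vae_scale_factor * 2))  # 128
lat_w = 2 * (int(width) // (vae_scale_factor * 2))   # 128

out = latents.view(batch_size, lat_h // 2, lat_w // 2, 64 // 4, 2, 2)
out = out.permute(0, 3, 1, 4, 2, 5)
out = out.reshape(batch_size, 64 // 4, lat_h, lat_w)
print(out.shape)  # torch.Size([1, 16, 128, 128]), ready for the VAE decoder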
@@ -465,13 +586,15 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
 
         shape = (batch_size, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -483,7 +606,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         return latents, latent_image_ids
 
@@ -509,16 +632,25 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
-        guidance_scale: float = 7.0,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
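A minimal usage sketch of the new call signature (the model id is illustrative): setting true_cfg_scale above 1 together with a negative prompt enables the second transformer pass added later in this diff, while sigmas replaces the removed timesteps argument.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a photo of a cat wearing a tiny wizard hat",
    negative_prompt="blurry, low quality",  # only used when true_cfg_scale > 1
    true_cfg_scale=4.0,   # new: real classifier-free guidance
    guidance_scale=3.5,   # new default (was 7.0), fed to the guidance embedding
    num_inference_steps=28,
).images[0]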
@@ -543,10 +675,10 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -568,6 +700,17 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -605,8 +748,12 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -628,6 +775,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -642,6 +790,21 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -657,7 +820,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         )
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
@@ -670,32 +833,61 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # handle guidance
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
@@ -706,6 +898,22 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                     return_dict=False,
                 )[0]
 
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
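The blend above is the standard classifier-free guidance formula. A tiny numeric check with made-up tensors shows why true CFG is gated on true_cfg_scale > 1: at a scale of exactly 1 the formula returns the conditional prediction unchanged.

import torch

true_cfg_scale = 4.0
noise_pred = torch.tensor([1.0, 2.0])      # conditional prediction (illustrative)
neg_noise_pred = torch.tensor([0.5, 0.5])  # negative-prompt prediction (illustrative)

guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
print(guided)  # tensor([2.5000, 6.5000]); with scale 1.0 it would equal noise_pred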