diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux_kontext.py (new file)
@@ -0,0 +1,1134 @@
1
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from ...models import AutoencoderKL, FluxTransformer2DModel
32
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
33
+ from ...utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from ...utils.torch_utils import randn_tensor
42
+ from ..pipeline_utils import DiffusionPipeline
43
+ from .pipeline_output import FluxPipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import FluxKontextPipeline
61
+ >>> from diffusers.utils import load_image
62
+
63
+ >>> pipe = FluxKontextPipeline.from_pretrained(
64
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
65
+ ... )
66
+ >>> pipe.to("cuda")
67
+
68
+ >>> image = load_image(
69
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
70
+ ... ).convert("RGB")
71
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
72
+ >>> image = pipe(
73
+ ... image=image,
74
+ ... prompt=prompt,
75
+ ... guidance_scale=2.5,
76
+ ... generator=torch.Generator().manual_seed(42),
77
+ ... ).images[0]
78
+ >>> image.save("output.png")
79
+ ```
80
+ """
81
+
82
+ PREFERRED_KONTEXT_RESOLUTIONS = [
83
+ (672, 1568),
84
+ (688, 1504),
85
+ (720, 1456),
86
+ (752, 1392),
87
+ (800, 1328),
88
+ (832, 1248),
89
+ (880, 1184),
90
+ (944, 1104),
91
+ (1024, 1024),
92
+ (1104, 944),
93
+ (1184, 880),
94
+ (1248, 832),
95
+ (1328, 800),
96
+ (1392, 752),
97
+ (1456, 720),
98
+ (1504, 688),
99
+ (1568, 672),
100
+ ]
101
+
102
+
103
+ def calculate_shift(
104
+ image_seq_len,
105
+ base_seq_len: int = 256,
106
+ max_seq_len: int = 4096,
107
+ base_shift: float = 0.5,
108
+ max_shift: float = 1.15,
109
+ ):
110
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
111
+ b = base_shift - m * base_seq_len
112
+ mu = image_seq_len * m + b
113
+ return mu
114
+
115
+
116
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
117
+ def retrieve_timesteps(
118
+ scheduler,
119
+ num_inference_steps: Optional[int] = None,
120
+ device: Optional[Union[str, torch.device]] = None,
121
+ timesteps: Optional[List[int]] = None,
122
+ sigmas: Optional[List[float]] = None,
123
+ **kwargs,
124
+ ):
125
+ r"""
126
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
127
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
128
+
129
+ Args:
130
+ scheduler (`SchedulerMixin`):
131
+ The scheduler to get timesteps from.
132
+ num_inference_steps (`int`):
133
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
134
+ must be `None`.
135
+ device (`str` or `torch.device`, *optional*):
136
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
137
+ timesteps (`List[int]`, *optional*):
138
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
139
+ `num_inference_steps` and `sigmas` must be `None`.
140
+ sigmas (`List[float]`, *optional*):
141
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
142
+ `num_inference_steps` and `timesteps` must be `None`.
143
+
144
+ Returns:
145
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
146
+ second element is the number of inference steps.
147
+ """
148
+ if timesteps is not None and sigmas is not None:
149
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
150
+ if timesteps is not None:
151
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
152
+ if not accepts_timesteps:
153
+ raise ValueError(
154
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
155
+ f" timestep schedules. Please check whether you are using the correct scheduler."
156
+ )
157
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
158
+ timesteps = scheduler.timesteps
159
+ num_inference_steps = len(timesteps)
160
+ elif sigmas is not None:
161
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
162
+ if not accept_sigmas:
163
+ raise ValueError(
164
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
165
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
166
+ )
167
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
168
+ timesteps = scheduler.timesteps
169
+ num_inference_steps = len(timesteps)
170
+ else:
171
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
172
+ timesteps = scheduler.timesteps
173
+ return timesteps, num_inference_steps
174
+
175
+
176
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
177
+ def retrieve_latents(
178
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
179
+ ):
180
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
181
+ return encoder_output.latent_dist.sample(generator)
182
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
183
+ return encoder_output.latent_dist.mode()
184
+ elif hasattr(encoder_output, "latents"):
185
+ return encoder_output.latents
186
+ else:
187
+ raise AttributeError("Could not access latents of provided encoder_output")
188
+
189
+
190
+ class FluxKontextPipeline(
191
+ DiffusionPipeline,
192
+ FluxLoraLoaderMixin,
193
+ FromSingleFileMixin,
194
+ TextualInversionLoaderMixin,
195
+ FluxIPAdapterMixin,
196
+ ):
197
+ r"""
198
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
199
+
200
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
201
+
202
+ Args:
203
+ transformer ([`FluxTransformer2DModel`]):
204
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
205
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
206
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
207
+ vae ([`AutoencoderKL`]):
208
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
209
+ text_encoder ([`CLIPTextModel`]):
210
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
211
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
212
+ text_encoder_2 ([`T5EncoderModel`]):
213
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
214
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
215
+ tokenizer (`CLIPTokenizer`):
216
+ Tokenizer of class
217
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
218
+ tokenizer_2 (`T5TokenizerFast`):
219
+ Second Tokenizer of class
220
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
221
+ """
222
+
223
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
224
+ _optional_components = ["image_encoder", "feature_extractor"]
225
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
226
+
227
+ def __init__(
228
+ self,
229
+ scheduler: FlowMatchEulerDiscreteScheduler,
230
+ vae: AutoencoderKL,
231
+ text_encoder: CLIPTextModel,
232
+ tokenizer: CLIPTokenizer,
233
+ text_encoder_2: T5EncoderModel,
234
+ tokenizer_2: T5TokenizerFast,
235
+ transformer: FluxTransformer2DModel,
236
+ image_encoder: CLIPVisionModelWithProjection = None,
237
+ feature_extractor: CLIPImageProcessor = None,
238
+ ):
239
+ super().__init__()
240
+
241
+ self.register_modules(
242
+ vae=vae,
243
+ text_encoder=text_encoder,
244
+ text_encoder_2=text_encoder_2,
245
+ tokenizer=tokenizer,
246
+ tokenizer_2=tokenizer_2,
247
+ transformer=transformer,
248
+ scheduler=scheduler,
249
+ image_encoder=image_encoder,
250
+ feature_extractor=feature_extractor,
251
+ )
252
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
253
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
254
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
255
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
256
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
257
+ self.tokenizer_max_length = (
258
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
259
+ )
260
+ self.default_sample_size = 128
261
+
262
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
263
+ def _get_t5_prompt_embeds(
264
+ self,
265
+ prompt: Union[str, List[str]] = None,
266
+ num_images_per_prompt: int = 1,
267
+ max_sequence_length: int = 512,
268
+ device: Optional[torch.device] = None,
269
+ dtype: Optional[torch.dtype] = None,
270
+ ):
271
+ device = device or self._execution_device
272
+ dtype = dtype or self.text_encoder.dtype
273
+
274
+ prompt = [prompt] if isinstance(prompt, str) else prompt
275
+ batch_size = len(prompt)
276
+
277
+ if isinstance(self, TextualInversionLoaderMixin):
278
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
279
+
280
+ text_inputs = self.tokenizer_2(
281
+ prompt,
282
+ padding="max_length",
283
+ max_length=max_sequence_length,
284
+ truncation=True,
285
+ return_length=False,
286
+ return_overflowing_tokens=False,
287
+ return_tensors="pt",
288
+ )
289
+ text_input_ids = text_inputs.input_ids
290
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
291
+
292
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
294
+ logger.warning(
295
+ "The following part of your input was truncated because `max_sequence_length` is set to "
296
+ f" {max_sequence_length} tokens: {removed_text}"
297
+ )
298
+
299
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
300
+
301
+ dtype = self.text_encoder_2.dtype
302
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
303
+
304
+ _, seq_len, _ = prompt_embeds.shape
305
+
306
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
307
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
308
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
309
+
310
+ return prompt_embeds
311
+
312
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
313
+ def _get_clip_prompt_embeds(
314
+ self,
315
+ prompt: Union[str, List[str]],
316
+ num_images_per_prompt: int = 1,
317
+ device: Optional[torch.device] = None,
318
+ ):
319
+ device = device or self._execution_device
320
+
321
+ prompt = [prompt] if isinstance(prompt, str) else prompt
322
+ batch_size = len(prompt)
323
+
324
+ if isinstance(self, TextualInversionLoaderMixin):
325
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
326
+
327
+ text_inputs = self.tokenizer(
328
+ prompt,
329
+ padding="max_length",
330
+ max_length=self.tokenizer_max_length,
331
+ truncation=True,
332
+ return_overflowing_tokens=False,
333
+ return_length=False,
334
+ return_tensors="pt",
335
+ )
336
+
337
+ text_input_ids = text_inputs.input_ids
338
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
339
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
340
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
341
+ logger.warning(
342
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
343
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
344
+ )
345
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
346
+
347
+ # Use pooled output of CLIPTextModel
348
+ prompt_embeds = prompt_embeds.pooler_output
349
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
350
+
351
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
352
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
353
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
354
+
355
+ return prompt_embeds
356
+
357
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
358
+ def encode_prompt(
359
+ self,
360
+ prompt: Union[str, List[str]],
361
+ prompt_2: Optional[Union[str, List[str]]] = None,
362
+ device: Optional[torch.device] = None,
363
+ num_images_per_prompt: int = 1,
364
+ prompt_embeds: Optional[torch.FloatTensor] = None,
365
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
366
+ max_sequence_length: int = 512,
367
+ lora_scale: Optional[float] = None,
368
+ ):
369
+ r"""
370
+
371
+ Args:
372
+ prompt (`str` or `List[str]`, *optional*):
373
+ prompt to be encoded
374
+ prompt_2 (`str` or `List[str]`, *optional*):
375
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
376
+ used in all text-encoders
377
+ device: (`torch.device`):
378
+ torch device
379
+ num_images_per_prompt (`int`):
380
+ number of images that should be generated per prompt
381
+ prompt_embeds (`torch.FloatTensor`, *optional*):
382
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
383
+ provided, text embeddings will be generated from `prompt` input argument.
384
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
385
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
386
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
387
+ lora_scale (`float`, *optional*):
388
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
389
+ """
390
+ device = device or self._execution_device
391
+
392
+ # set lora scale so that monkey patched LoRA
393
+ # function of text encoder can correctly access it
394
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
395
+ self._lora_scale = lora_scale
396
+
397
+ # dynamically adjust the LoRA scale
398
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
399
+ scale_lora_layers(self.text_encoder, lora_scale)
400
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
401
+ scale_lora_layers(self.text_encoder_2, lora_scale)
402
+
403
+ prompt = [prompt] if isinstance(prompt, str) else prompt
404
+
405
+ if prompt_embeds is None:
406
+ prompt_2 = prompt_2 or prompt
407
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
408
+
409
+ # We only use the pooled prompt output from the CLIPTextModel
410
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
411
+ prompt=prompt,
412
+ device=device,
413
+ num_images_per_prompt=num_images_per_prompt,
414
+ )
415
+ prompt_embeds = self._get_t5_prompt_embeds(
416
+ prompt=prompt_2,
417
+ num_images_per_prompt=num_images_per_prompt,
418
+ max_sequence_length=max_sequence_length,
419
+ device=device,
420
+ )
421
+
422
+ if self.text_encoder is not None:
423
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
424
+ # Retrieve the original scale by scaling back the LoRA layers
425
+ unscale_lora_layers(self.text_encoder, lora_scale)
426
+
427
+ if self.text_encoder_2 is not None:
428
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
429
+ # Retrieve the original scale by scaling back the LoRA layers
430
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
431
+
432
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
433
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
434
+
435
+ return prompt_embeds, pooled_prompt_embeds, text_ids
436
+
437
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
438
+ def encode_image(self, image, device, num_images_per_prompt):
439
+ dtype = next(self.image_encoder.parameters()).dtype
440
+
441
+ if not isinstance(image, torch.Tensor):
442
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
443
+
444
+ image = image.to(device=device, dtype=dtype)
445
+ image_embeds = self.image_encoder(image).image_embeds
446
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
447
+ return image_embeds
448
+
449
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
450
+ def prepare_ip_adapter_image_embeds(
451
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
452
+ ):
453
+ image_embeds = []
454
+ if ip_adapter_image_embeds is None:
455
+ if not isinstance(ip_adapter_image, list):
456
+ ip_adapter_image = [ip_adapter_image]
457
+
458
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
459
+ raise ValueError(
460
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
461
+ )
462
+
463
+ for single_ip_adapter_image in ip_adapter_image:
464
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
465
+ image_embeds.append(single_image_embeds[None, :])
466
+ else:
467
+ if not isinstance(ip_adapter_image_embeds, list):
468
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
469
+
470
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
471
+ raise ValueError(
472
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
473
+ )
474
+
475
+ for single_image_embeds in ip_adapter_image_embeds:
476
+ image_embeds.append(single_image_embeds)
477
+
478
+ ip_adapter_image_embeds = []
479
+ for single_image_embeds in image_embeds:
480
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
481
+ single_image_embeds = single_image_embeds.to(device=device)
482
+ ip_adapter_image_embeds.append(single_image_embeds)
483
+
484
+ return ip_adapter_image_embeds
485
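For reference, a shape-level sketch of the contract `prepare_ip_adapter_image_embeds` expects when embeddings are passed in directly (the embedding width 768 is a placeholder, not taken from this diff):

```python
import torch

# One entry per loaded IP-Adapter, each of shape (batch_size, num_images, emb_dim).
# Passing precomputed embeds through `ip_adapter_image_embeds` skips `encode_image` entirely.
ip_adapter_image_embeds = [torch.randn(1, 1, 768)]  # placeholder values

# image = pipe(..., ip_adapter_image_embeds=ip_adapter_image_embeds).images[0]
```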
+
486
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
487
+ def check_inputs(
488
+ self,
489
+ prompt,
490
+ prompt_2,
491
+ height,
492
+ width,
493
+ negative_prompt=None,
494
+ negative_prompt_2=None,
495
+ prompt_embeds=None,
496
+ negative_prompt_embeds=None,
497
+ pooled_prompt_embeds=None,
498
+ negative_pooled_prompt_embeds=None,
499
+ callback_on_step_end_tensor_inputs=None,
500
+ max_sequence_length=None,
501
+ ):
502
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
503
+ logger.warning(
504
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
505
+ )
506
+
507
+ if callback_on_step_end_tensor_inputs is not None and not all(
508
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
509
+ ):
510
+ raise ValueError(
511
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
512
+ )
513
+
514
+ if prompt is not None and prompt_embeds is not None:
515
+ raise ValueError(
516
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
517
+ " only forward one of the two."
518
+ )
519
+ elif prompt_2 is not None and prompt_embeds is not None:
520
+ raise ValueError(
521
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
522
+ " only forward one of the two."
523
+ )
524
+ elif prompt is None and prompt_embeds is None:
525
+ raise ValueError(
526
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
527
+ )
528
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
529
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
530
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
531
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
532
+
533
+ if negative_prompt is not None and negative_prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
536
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
537
+ )
538
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
539
+ raise ValueError(
540
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
541
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
542
+ )
543
+
544
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
545
+ raise ValueError(
546
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
547
+ )
548
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
549
+ raise ValueError(
550
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
551
+ )
552
+
553
+ if max_sequence_length is not None and max_sequence_length > 512:
554
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
555
+
556
+ @staticmethod
557
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
558
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
559
+ latent_image_ids = torch.zeros(height, width, 3)
560
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
561
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
562
+
563
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
564
+
565
+ latent_image_ids = latent_image_ids.reshape(
566
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
567
+ )
568
+
569
+ return latent_image_ids.to(device=device, dtype=dtype)
570
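A standalone sketch (plain torch, nothing pipeline-specific) of the (id, row, column) layout `_prepare_latent_image_ids` produces; the conditioning image reuses the same grid with the first channel set to 1, as noted in `prepare_latents` below:

```python
import torch

height, width = 2, 3                            # packed grid: latent_height // 2 by latent_width // 2
ids = torch.zeros(height, width, 3)
ids[..., 1] += torch.arange(height)[:, None]    # row coordinate
ids[..., 2] += torch.arange(width)[None, :]     # column coordinate
ids = ids.reshape(height * width, 3)
# ids[:, 0] stays 0 for the noise latents; prepare_latents sets it to 1 for the image latents.
print(ids)
# tensor([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.],
#         [0., 1., 0.], [0., 1., 1.], [0., 1., 2.]])
```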
+
571
+ @staticmethod
572
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
573
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
574
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
575
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
576
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
577
+
578
+ return latents
579
+
580
+ @staticmethod
581
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
582
+ def _unpack_latents(latents, height, width, vae_scale_factor):
583
+ batch_size, num_patches, channels = latents.shape
584
+
585
+ # VAE applies 8x compression on images but we must also account for packing which requires
586
+ # latent height and width to be divisible by 2.
587
+ height = 2 * (int(height) // (vae_scale_factor * 2))
588
+ width = 2 * (int(width) // (vae_scale_factor * 2))
589
+
590
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
591
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
592
+
593
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
594
+
595
+ return latents
596
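A quick round-trip check (independent of the pipeline, assuming the usual vae_scale_factor of 8) showing that `_pack_latents` and `_unpack_latents` are exact inverses, turning a `(B, C, H, W)` latent into one token per 2x2 patch and back:

```python
import torch

B, C, lat_h, lat_w, vae_scale = 1, 16, 4, 4, 8
x = torch.randn(B, C, lat_h, lat_w)

# _pack_latents: (B, C, H, W) -> (B, (H/2)*(W/2), C*4)
packed = x.view(B, C, lat_h // 2, 2, lat_w // 2, 2).permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(B, (lat_h // 2) * (lat_w // 2), C * 4)
print(packed.shape)  # torch.Size([1, 4, 64])

# _unpack_latents receives pixel-space height/width and divides by vae_scale_factor * 2
h = 2 * ((lat_h * vae_scale) // (vae_scale * 2))   # == lat_h
w = 2 * ((lat_w * vae_scale) // (vae_scale * 2))   # == lat_w
unpacked = packed.view(B, h // 2, w // 2, C, 2, 2).permute(0, 3, 1, 4, 2, 5).reshape(B, C, h, w)

assert torch.equal(x, unpacked)
```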
+
597
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
598
+ if isinstance(generator, list):
599
+ image_latents = [
600
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
601
+ for i in range(image.shape[0])
602
+ ]
603
+ image_latents = torch.cat(image_latents, dim=0)
604
+ else:
605
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
606
+
607
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
608
+
609
+ return image_latents
610
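The normalization applied above is undone right before decoding at the end of `__call__`; a small sketch with placeholder values (the real `shift_factor` and `scaling_factor` come from the VAE config and are not shown in this section):

```python
# Placeholder values; the real ones live in self.vae.config.
shift_factor, scaling_factor = 0.12, 0.36

def to_model_space(image_latents):
    # applied in _encode_vae_image after retrieve_latents
    return (image_latents - shift_factor) * scaling_factor

def to_vae_space(latents):
    # applied before self.vae.decode at the end of __call__
    return latents / scaling_factor + shift_factor
```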
+
611
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
612
+ def enable_vae_slicing(self):
613
+ r"""
614
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
615
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
616
+ """
617
+ self.vae.enable_slicing()
618
+
619
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
620
+ def disable_vae_slicing(self):
621
+ r"""
622
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
623
+ computing decoding in one step.
624
+ """
625
+ self.vae.disable_slicing()
626
+
627
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
628
+ def enable_vae_tiling(self):
629
+ r"""
630
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
631
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
632
+ processing larger images.
633
+ """
634
+ self.vae.enable_tiling()
635
+
636
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
637
+ def disable_vae_tiling(self):
638
+ r"""
639
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
640
+ computing decoding in one step.
641
+ """
642
+ self.vae.disable_tiling()
643
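A hypothetical usage sketch of the four toggles above (assuming `pipe` is a loaded instance of this pipeline): slicing decodes a batch one image at a time, tiling decodes each image tile by tile, both trading a little speed for peak memory:

```python
pipe.enable_vae_slicing()   # decode batched outputs one image at a time
pipe.enable_vae_tiling()    # decode each image tile by tile

# images = pipe(image=init_image, prompt="...", num_images_per_prompt=4).images

pipe.disable_vae_tiling()
pipe.disable_vae_slicing()
```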
+
644
+ def prepare_latents(
645
+ self,
646
+ image: Optional[torch.Tensor],
647
+ batch_size: int,
648
+ num_channels_latents: int,
649
+ height: int,
650
+ width: int,
651
+ dtype: torch.dtype,
652
+ device: torch.device,
653
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
654
+ latents: Optional[torch.Tensor] = None,
655
+ ):
656
+ if isinstance(generator, list) and len(generator) != batch_size:
657
+ raise ValueError(
658
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
659
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
660
+ )
661
+
662
+ # VAE applies 8x compression on images but we must also account for packing which requires
663
+ # latent height and width to be divisible by 2.
664
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
665
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
666
+ shape = (batch_size, num_channels_latents, height, width)
667
+
668
+ image_latents = image_ids = None
669
+ if image is not None:
670
+ image = image.to(device=device, dtype=dtype)
671
+ if image.shape[1] != self.latent_channels:
672
+ image_latents = self._encode_vae_image(image=image, generator=generator)
673
+ else:
674
+ image_latents = image
675
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
676
+ # expand init_latents for batch_size
677
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
678
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
679
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
680
+ raise ValueError(
681
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
682
+ )
683
+ else:
684
+ image_latents = torch.cat([image_latents], dim=0)
685
+
686
+ image_latent_height, image_latent_width = image_latents.shape[2:]
687
+ image_latents = self._pack_latents(
688
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
689
+ )
690
+ image_ids = self._prepare_latent_image_ids(
691
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
692
+ )
693
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
694
+ image_ids[..., 0] = 1
695
+
696
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
697
+
698
+ if latents is None:
699
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
700
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
701
+ else:
702
+ latents = latents.to(device=device, dtype=dtype)
703
+
704
+ return latents, image_latents, latent_ids, image_ids
705
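The shape bookkeeping in `prepare_latents`, worked through for a 1024x1024 request (vae_scale_factor = 8 and 64 transformer input channels are assumptions based on the standard Flux configuration):

```python
height, width = 1024, 1024
vae_scale_factor = 8
num_channels_latents = 64 // 4                          # transformer.config.in_channels // 4 -> 16

lat_h = 2 * (height // (vae_scale_factor * 2))          # 128
lat_w = 2 * (width // (vae_scale_factor * 2))           # 128
noise_shape = (1, num_channels_latents, lat_h, lat_w)   # (1, 16, 128, 128) before packing
packed_seq_len = (lat_h // 2) * (lat_w // 2)            # 4096 tokens after _pack_latents
```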
+
706
+ @property
707
+ def guidance_scale(self):
708
+ return self._guidance_scale
709
+
710
+ @property
711
+ def joint_attention_kwargs(self):
712
+ return self._joint_attention_kwargs
713
+
714
+ @property
715
+ def num_timesteps(self):
716
+ return self._num_timesteps
717
+
718
+ @property
719
+ def current_timestep(self):
720
+ return self._current_timestep
721
+
722
+ @property
723
+ def interrupt(self):
724
+ return self._interrupt
725
+
726
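The `interrupt` property is only read inside the denoising loop further down; a hypothetical early-stop callback (the helper name and the use of the private `_interrupt` attribute are assumptions for illustration, not part of this diff):

```python
def stop_after(n):
    # Hypothetical helper: flip the flag behind `interrupt` once `n` steps have run.
    def callback(pipeline, step, timestep, callback_kwargs):
        if step + 1 >= n:
            pipeline._interrupt = True
        return callback_kwargs
    return callback

# image = pipe(image=init_image, prompt="...", callback_on_step_end=stop_after(10)).images[0]
```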
+ @torch.no_grad()
727
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
728
+ def __call__(
729
+ self,
730
+ image: Optional[PipelineImageInput] = None,
731
+ prompt: Union[str, List[str]] = None,
732
+ prompt_2: Optional[Union[str, List[str]]] = None,
733
+ negative_prompt: Union[str, List[str]] = None,
734
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
735
+ true_cfg_scale: float = 1.0,
736
+ height: Optional[int] = None,
737
+ width: Optional[int] = None,
738
+ num_inference_steps: int = 28,
739
+ sigmas: Optional[List[float]] = None,
740
+ guidance_scale: float = 3.5,
741
+ num_images_per_prompt: Optional[int] = 1,
742
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
743
+ latents: Optional[torch.FloatTensor] = None,
744
+ prompt_embeds: Optional[torch.FloatTensor] = None,
745
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
746
+ ip_adapter_image: Optional[PipelineImageInput] = None,
747
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
748
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
749
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
750
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
751
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
752
+ output_type: Optional[str] = "pil",
753
+ return_dict: bool = True,
754
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
755
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
756
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
757
+ max_sequence_length: int = 512,
758
+ max_area: int = 1024**2,
759
+ _auto_resize: bool = True,
760
+ ):
761
+ r"""
762
+ Function invoked when calling the pipeline for generation.
763
+
764
+ Args:
765
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
766
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
767
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
768
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
769
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
770
+ latents as `image`, but if passing latents directly it is not encoded again.
771
+ prompt (`str` or `List[str]`, *optional*):
772
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
773
+ instead.
774
+ prompt_2 (`str` or `List[str]`, *optional*):
775
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
776
+ will be used instead.
777
+ negative_prompt (`str` or `List[str]`, *optional*):
778
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
779
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
780
+ not greater than `1`).
781
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
782
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
783
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
784
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
785
+ When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
786
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
787
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
788
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
789
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
790
+ num_inference_steps (`int`, *optional*, defaults to 28):
791
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
792
+ expense of slower inference.
793
+ sigmas (`List[float]`, *optional*):
794
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
795
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
796
+ will be used.
797
+ guidance_scale (`float`, *optional*, defaults to 3.5):
798
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
799
+ a model to generate images more closely aligned with the prompt, at the expense of lower image quality.
800
+
801
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
802
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
803
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
804
+ The number of images to generate per prompt.
805
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
806
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
807
+ to make generation deterministic.
808
+ latents (`torch.FloatTensor`, *optional*):
809
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
810
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
811
+ tensor will be generated by sampling using the supplied random `generator`.
812
+ prompt_embeds (`torch.FloatTensor`, *optional*):
813
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
814
+ provided, text embeddings will be generated from `prompt` input argument.
815
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
816
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
817
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
818
+ ip_adapter_image (`PipelineImageInput`, *optional*):
819
+ Optional image input to work with IP Adapters.
820
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
821
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
822
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
823
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
824
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
825
+ Optional image input to work with IP Adapters.
826
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
827
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
828
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
829
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
830
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
831
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
832
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
833
+ argument.
834
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
835
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
836
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
837
+ input argument.
838
+ output_type (`str`, *optional*, defaults to `"pil"`):
839
+ The output format of the generated image. Choose between
840
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
841
+ return_dict (`bool`, *optional*, defaults to `True`):
842
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
843
+ joint_attention_kwargs (`dict`, *optional*):
844
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
845
+ `self.processor` in
846
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
847
+ callback_on_step_end (`Callable`, *optional*):
848
+ A function that is called at the end of each denoising step during inference. The function is called
849
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
850
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
851
+ `callback_on_step_end_tensor_inputs`.
852
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
853
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
854
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
855
+ `._callback_tensor_inputs` attribute of your pipeline class.
856
+ max_sequence_length (`int`, *optional*, defaults to 512):
857
+ Maximum sequence length to use with the `prompt`.
858
+ max_area (`int`, defaults to `1024 ** 2`):
859
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
860
+ area while maintaining the aspect ratio.
861
+
862
+ Examples:
863
+
864
+ Returns:
865
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
866
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
867
+ images.
868
+ """
869
+
870
+ height = height or self.default_sample_size * self.vae_scale_factor
871
+ width = width or self.default_sample_size * self.vae_scale_factor
872
+
873
+ original_height, original_width = height, width
874
+ aspect_ratio = width / height
875
+ width = round((max_area * aspect_ratio) ** 0.5)
876
+ height = round((max_area / aspect_ratio) ** 0.5)
877
+
878
+ multiple_of = self.vae_scale_factor * 2
879
+ width = width // multiple_of * multiple_of
880
+ height = height // multiple_of * multiple_of
881
+
882
+ if height != original_height or width != original_width:
883
+ logger.warning(
884
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
885
+ )
886
+
887
+ # 1. Check inputs. Raise error if not correct
888
+ self.check_inputs(
889
+ prompt,
890
+ prompt_2,
891
+ height,
892
+ width,
893
+ negative_prompt=negative_prompt,
894
+ negative_prompt_2=negative_prompt_2,
895
+ prompt_embeds=prompt_embeds,
896
+ negative_prompt_embeds=negative_prompt_embeds,
897
+ pooled_prompt_embeds=pooled_prompt_embeds,
898
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
899
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
900
+ max_sequence_length=max_sequence_length,
901
+ )
902
+
903
+ self._guidance_scale = guidance_scale
904
+ self._joint_attention_kwargs = joint_attention_kwargs
905
+ self._current_timestep = None
906
+ self._interrupt = False
907
+
908
+ # 2. Define call parameters
909
+ if prompt is not None and isinstance(prompt, str):
910
+ batch_size = 1
911
+ elif prompt is not None and isinstance(prompt, list):
912
+ batch_size = len(prompt)
913
+ else:
914
+ batch_size = prompt_embeds.shape[0]
915
+
916
+ device = self._execution_device
917
+
918
+ lora_scale = (
919
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
920
+ )
921
+ has_neg_prompt = negative_prompt is not None or (
922
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
923
+ )
924
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
925
+ (
926
+ prompt_embeds,
927
+ pooled_prompt_embeds,
928
+ text_ids,
929
+ ) = self.encode_prompt(
930
+ prompt=prompt,
931
+ prompt_2=prompt_2,
932
+ prompt_embeds=prompt_embeds,
933
+ pooled_prompt_embeds=pooled_prompt_embeds,
934
+ device=device,
935
+ num_images_per_prompt=num_images_per_prompt,
936
+ max_sequence_length=max_sequence_length,
937
+ lora_scale=lora_scale,
938
+ )
939
+ if do_true_cfg:
940
+ (
941
+ negative_prompt_embeds,
942
+ negative_pooled_prompt_embeds,
943
+ negative_text_ids,
944
+ ) = self.encode_prompt(
945
+ prompt=negative_prompt,
946
+ prompt_2=negative_prompt_2,
947
+ prompt_embeds=negative_prompt_embeds,
948
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
949
+ device=device,
950
+ num_images_per_prompt=num_images_per_prompt,
951
+ max_sequence_length=max_sequence_length,
952
+ lora_scale=lora_scale,
953
+ )
954
+
955
+ # 3. Preprocess image
956
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
957
+ img = image[0] if isinstance(image, list) else image
958
+ image_height, image_width = self.image_processor.get_default_height_width(img)
959
+ aspect_ratio = image_width / image_height
960
+ if _auto_resize:
961
+ # Kontext is trained on specific resolutions, using one of them is recommended
962
+ _, image_width, image_height = min(
963
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
964
+ )
965
+ image_width = image_width // multiple_of * multiple_of
966
+ image_height = image_height // multiple_of * multiple_of
967
+ image = self.image_processor.resize(image, image_height, image_width)
968
+ image = self.image_processor.preprocess(image, image_height, image_width)
969
+
970
+ # 4. Prepare latent variables
971
+ num_channels_latents = self.transformer.config.in_channels // 4
972
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
973
+ image,
974
+ batch_size * num_images_per_prompt,
975
+ num_channels_latents,
976
+ height,
977
+ width,
978
+ prompt_embeds.dtype,
979
+ device,
980
+ generator,
981
+ latents,
982
+ )
983
+ if image_ids is not None:
984
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
985
+
986
+ # 5. Prepare timesteps
987
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
988
+ image_seq_len = latents.shape[1]
989
+ mu = calculate_shift(
990
+ image_seq_len,
991
+ self.scheduler.config.get("base_image_seq_len", 256),
992
+ self.scheduler.config.get("max_image_seq_len", 4096),
993
+ self.scheduler.config.get("base_shift", 0.5),
994
+ self.scheduler.config.get("max_shift", 1.15),
995
+ )
996
+ timesteps, num_inference_steps = retrieve_timesteps(
997
+ self.scheduler,
998
+ num_inference_steps,
999
+ device,
1000
+ sigmas=sigmas,
1001
+ mu=mu,
1002
+ )
1003
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1004
+ self._num_timesteps = len(timesteps)
1005
+
1006
+ # handle guidance
1007
+ if self.transformer.config.guidance_embeds:
1008
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1009
+ guidance = guidance.expand(latents.shape[0])
1010
+ else:
1011
+ guidance = None
1012
+
1013
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1014
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1015
+ ):
1016
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1017
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1018
+
1019
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1020
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1021
+ ):
1022
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1023
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1024
+
1025
+ if self.joint_attention_kwargs is None:
1026
+ self._joint_attention_kwargs = {}
1027
+
1028
+ image_embeds = None
1029
+ negative_image_embeds = None
1030
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1031
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1032
+ ip_adapter_image,
1033
+ ip_adapter_image_embeds,
1034
+ device,
1035
+ batch_size * num_images_per_prompt,
1036
+ )
1037
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1038
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1039
+ negative_ip_adapter_image,
1040
+ negative_ip_adapter_image_embeds,
1041
+ device,
1042
+ batch_size * num_images_per_prompt,
1043
+ )
1044
+
1045
+ # 6. Denoising loop
1046
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
1047
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
1048
+ self.scheduler.set_begin_index(0)
1049
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1050
+ for i, t in enumerate(timesteps):
1051
+ if self.interrupt:
1052
+ continue
1053
+
1054
+ self._current_timestep = t
1055
+ if image_embeds is not None:
1056
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1057
+
1058
+ latent_model_input = latents
1059
+ if image_latents is not None:
1060
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
1061
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1062
+
1063
+ noise_pred = self.transformer(
1064
+ hidden_states=latent_model_input,
1065
+ timestep=timestep / 1000,
1066
+ guidance=guidance,
1067
+ pooled_projections=pooled_prompt_embeds,
1068
+ encoder_hidden_states=prompt_embeds,
1069
+ txt_ids=text_ids,
1070
+ img_ids=latent_ids,
1071
+ joint_attention_kwargs=self.joint_attention_kwargs,
1072
+ return_dict=False,
1073
+ )[0]
1074
+ noise_pred = noise_pred[:, : latents.size(1)]
1075
+
1076
+ if do_true_cfg:
1077
+ if negative_image_embeds is not None:
1078
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1079
+ neg_noise_pred = self.transformer(
1080
+ hidden_states=latent_model_input,
1081
+ timestep=timestep / 1000,
1082
+ guidance=guidance,
1083
+ pooled_projections=negative_pooled_prompt_embeds,
1084
+ encoder_hidden_states=negative_prompt_embeds,
1085
+ txt_ids=negative_text_ids,
1086
+ img_ids=latent_ids,
1087
+ joint_attention_kwargs=self.joint_attention_kwargs,
1088
+ return_dict=False,
1089
+ )[0]
1090
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
1091
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1092
+
1093
+ # compute the previous noisy sample x_t -> x_t-1
1094
+ latents_dtype = latents.dtype
1095
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1096
+
1097
+ if latents.dtype != latents_dtype:
1098
+ if torch.backends.mps.is_available():
1099
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1100
+ latents = latents.to(latents_dtype)
1101
+
1102
+ if callback_on_step_end is not None:
1103
+ callback_kwargs = {}
1104
+ for k in callback_on_step_end_tensor_inputs:
1105
+ callback_kwargs[k] = locals()[k]
1106
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1107
+
1108
+ latents = callback_outputs.pop("latents", latents)
1109
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1110
+
1111
+ # call the callback, if provided
1112
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1113
+ progress_bar.update()
1114
+
1115
+ if XLA_AVAILABLE:
1116
+ xm.mark_step()
1117
+
1118
+ self._current_timestep = None
1119
+
1120
+ if output_type == "latent":
1121
+ image = latents
1122
+ else:
1123
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1124
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1125
+ image = self.vae.decode(latents, return_dict=False)[0]
1126
+ image = self.image_processor.postprocess(image, output_type=output_type)
1127
+
1128
+ # Offload all models
1129
+ self.maybe_free_model_hooks()
1130
+
1131
+ if not return_dict:
1132
+ return (image,)
1133
+
1134
+ return FluxPipelineOutput(images=image)
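Finally, a hedged end-to-end sketch of calling the pipeline added in this diff (the `FluxKontextPipeline` export name, checkpoint id, and image URL are assumptions, not taken from this file):

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16  # assumed checkpoint id
).to("cuda")

init_image = load_image("https://example.com/input.png")  # placeholder URL
image = pipe(
    image=init_image,
    prompt="make it look like a watercolor painting",
    guidance_scale=3.5,          # embedded (distilled) guidance
    true_cfg_scale=1.0,          # > 1.0 plus a negative_prompt enables true CFG
    num_inference_steps=28,
).images[0]
image.save("kontext_edit.png")
```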