diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
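The quickest way to confirm which side of this comparison is installed locally is to check the package version and try importing one of the modules the newer release adds, such as the `FluxKontextInpaintPipeline` introduced by file 287 below. A minimal sketch, assuming a standard pip-managed environment:

```py
# Upgrade to the newer of the two versions compared here (run in a shell):
#   pip install --upgrade "diffusers==0.35.0"

import diffusers

print(diffusers.__version__)  # "0.33.1" before the upgrade, "0.35.0" after

# FluxKontextInpaintPipeline (file 287 in the list below) only exists on the
# 0.35.0 side of this diff, so importing it is a simple presence check.
from diffusers import FluxKontextInpaintPipeline  # noqa: F401
```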
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1460 @@
1
+ # Copyright 2025 ZenAI. All rights reserved.
2
+ # author: @vuongminh1907
3
+
4
+ import inspect
5
+ from typing import Any, Callable, Dict, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+ from transformers import (
11
+ CLIPImageProcessor,
12
+ CLIPTextModel,
13
+ CLIPTokenizer,
14
+ CLIPVisionModelWithProjection,
15
+ T5EncoderModel,
16
+ T5TokenizerFast,
17
+ )
18
+
19
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
20
+ from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
21
+ from ...models import AutoencoderKL, FluxTransformer2DModel
22
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
23
+ from ...utils import (
24
+ USE_PEFT_BACKEND,
25
+ is_torch_xla_available,
26
+ logging,
27
+ replace_example_docstring,
28
+ scale_lora_layers,
29
+ unscale_lora_layers,
30
+ )
31
+ from ...utils.torch_utils import randn_tensor
32
+ from ..pipeline_utils import DiffusionPipeline
33
+ from .pipeline_output import FluxPipelineOutput
34
+
35
+
36
+ if is_torch_xla_available():
37
+ import torch_xla.core.xla_model as xm
38
+
39
+ XLA_AVAILABLE = True
40
+ else:
41
+ XLA_AVAILABLE = False
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ EXAMPLE_DOC_STRING = """
47
+ Examples:
48
+ # Inpainting with text only
49
+ ```py
50
+ >>> import torch
51
+ >>> from diffusers import FluxKontextInpaintPipeline
52
+ >>> from diffusers.utils import load_image
53
+
54
+ >>> prompt = "Change the yellow dinosaur to green one"
55
+ >>> img_url = (
56
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
57
+ ... )
58
+ >>> mask_url = (
59
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
60
+ ... )
61
+
62
+ >>> source = load_image(img_url)
63
+ >>> mask = load_image(mask_url)
64
+
65
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
66
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
67
+ ... )
68
+ >>> pipe.to("cuda")
69
+
70
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
71
+ >>> image.save("kontext_inpainting_normal.png")
72
+ ```
73
+
74
+ # Inpainting with image conditioning
75
+ ```py
76
+ >>> import torch
77
+ >>> from diffusers import FluxKontextInpaintPipeline
78
+ >>> from diffusers.utils import load_image
79
+
80
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
81
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
82
+ ... )
83
+ >>> pipe.to("cuda")
84
+
85
+ >>> prompt = "Replace this ball"
86
+ >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
87
+ >>> mask_url = (
88
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
89
+ ... )
90
+ >>> image_reference_url = (
91
+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
92
+ ... )
93
+
94
+ >>> source = load_image(img_url)
95
+ >>> mask = load_image(mask_url)
96
+ >>> image_reference = load_image(image_reference_url)
97
+
98
+ >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
99
+ >>> image = pipe(
100
+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
101
+ ... ).images[0]
102
+ >>> image.save("kontext_inpainting_ref.png")
103
+ ```
104
+ """
105
+
106
+ PREFERRED_KONTEXT_RESOLUTIONS = [
107
+ (672, 1568),
108
+ (688, 1504),
109
+ (720, 1456),
110
+ (752, 1392),
111
+ (800, 1328),
112
+ (832, 1248),
113
+ (880, 1184),
114
+ (944, 1104),
115
+ (1024, 1024),
116
+ (1104, 944),
117
+ (1184, 880),
118
+ (1248, 832),
119
+ (1328, 800),
120
+ (1392, 752),
121
+ (1456, 720),
122
+ (1504, 688),
123
+ (1568, 672),
124
+ ]
125
+
126
+
127
+ def calculate_shift(
128
+ image_seq_len,
129
+ base_seq_len: int = 256,
130
+ max_seq_len: int = 4096,
131
+ base_shift: float = 0.5,
132
+ max_shift: float = 1.15,
133
+ ):
134
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
135
+ b = base_shift - m * base_seq_len
136
+ mu = image_seq_len * m + b
137
+ return mu
138
+
139
+
140
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
141
+ def retrieve_timesteps(
142
+ scheduler,
143
+ num_inference_steps: Optional[int] = None,
144
+ device: Optional[Union[str, torch.device]] = None,
145
+ timesteps: Optional[List[int]] = None,
146
+ sigmas: Optional[List[float]] = None,
147
+ **kwargs,
148
+ ):
149
+ r"""
150
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
151
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
152
+
153
+ Args:
154
+ scheduler (`SchedulerMixin`):
155
+ The scheduler to get timesteps from.
156
+ num_inference_steps (`int`):
157
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
158
+ must be `None`.
159
+ device (`str` or `torch.device`, *optional*):
160
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
161
+ timesteps (`List[int]`, *optional*):
162
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
163
+ `num_inference_steps` and `sigmas` must be `None`.
164
+ sigmas (`List[float]`, *optional*):
165
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
166
+ `num_inference_steps` and `timesteps` must be `None`.
167
+
168
+ Returns:
169
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
170
+ second element is the number of inference steps.
171
+ """
172
+ if timesteps is not None and sigmas is not None:
173
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
174
+ if timesteps is not None:
175
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
176
+ if not accepts_timesteps:
177
+ raise ValueError(
178
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
179
+ f" timestep schedules. Please check whether you are using the correct scheduler."
180
+ )
181
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
182
+ timesteps = scheduler.timesteps
183
+ num_inference_steps = len(timesteps)
184
+ elif sigmas is not None:
185
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
186
+ if not accept_sigmas:
187
+ raise ValueError(
188
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
189
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
190
+ )
191
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
192
+ timesteps = scheduler.timesteps
193
+ num_inference_steps = len(timesteps)
194
+ else:
195
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
196
+ timesteps = scheduler.timesteps
197
+ return timesteps, num_inference_steps
198
+
199
+
200
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
201
+ def retrieve_latents(
202
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
203
+ ):
204
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
205
+ return encoder_output.latent_dist.sample(generator)
206
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
207
+ return encoder_output.latent_dist.mode()
208
+ elif hasattr(encoder_output, "latents"):
209
+ return encoder_output.latents
210
+ else:
211
+ raise AttributeError("Could not access latents of provided encoder_output")
212
+
213
+
214
+ class FluxKontextInpaintPipeline(
215
+ DiffusionPipeline,
216
+ FluxLoraLoaderMixin,
217
+ FromSingleFileMixin,
218
+ TextualInversionLoaderMixin,
219
+ FluxIPAdapterMixin,
220
+ ):
221
+ r"""
222
+ The Flux Kontext pipeline for text-to-image generation.
223
+
224
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
225
+
226
+ Args:
227
+ transformer ([`FluxTransformer2DModel`]):
228
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
229
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
230
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
231
+ vae ([`AutoencoderKL`]):
232
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
233
+ text_encoder ([`CLIPTextModel`]):
234
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
235
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
236
+ text_encoder_2 ([`T5EncoderModel`]):
237
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
238
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
239
+ tokenizer (`CLIPTokenizer`):
240
+ Tokenizer of class
241
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
242
+ tokenizer_2 (`T5TokenizerFast`):
243
+ Second Tokenizer of class
244
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
245
+ """
246
+
247
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
248
+ _optional_components = ["image_encoder", "feature_extractor"]
249
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
250
+
251
+ def __init__(
252
+ self,
253
+ scheduler: FlowMatchEulerDiscreteScheduler,
254
+ vae: AutoencoderKL,
255
+ text_encoder: CLIPTextModel,
256
+ tokenizer: CLIPTokenizer,
257
+ text_encoder_2: T5EncoderModel,
258
+ tokenizer_2: T5TokenizerFast,
259
+ transformer: FluxTransformer2DModel,
260
+ image_encoder: CLIPVisionModelWithProjection = None,
261
+ feature_extractor: CLIPImageProcessor = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ self.register_modules(
266
+ vae=vae,
267
+ text_encoder=text_encoder,
268
+ text_encoder_2=text_encoder_2,
269
+ tokenizer=tokenizer,
270
+ tokenizer_2=tokenizer_2,
271
+ transformer=transformer,
272
+ scheduler=scheduler,
273
+ image_encoder=image_encoder,
274
+ feature_extractor=feature_extractor,
275
+ )
276
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
277
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
278
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
279
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
280
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
281
+
282
+ self.mask_processor = VaeImageProcessor(
283
+ vae_scale_factor=self.vae_scale_factor * 2,
284
+ vae_latent_channels=self.latent_channels,
285
+ do_normalize=False,
286
+ do_binarize=True,
287
+ do_convert_grayscale=True,
288
+ )
289
+
290
+ self.tokenizer_max_length = (
291
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
292
+ )
293
+ self.default_sample_size = 128
294
+
295
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
296
+ def _get_t5_prompt_embeds(
297
+ self,
298
+ prompt: Union[str, List[str]] = None,
299
+ num_images_per_prompt: int = 1,
300
+ max_sequence_length: int = 512,
301
+ device: Optional[torch.device] = None,
302
+ dtype: Optional[torch.dtype] = None,
303
+ ):
304
+ device = device or self._execution_device
305
+ dtype = dtype or self.text_encoder.dtype
306
+
307
+ prompt = [prompt] if isinstance(prompt, str) else prompt
308
+ batch_size = len(prompt)
309
+
310
+ if isinstance(self, TextualInversionLoaderMixin):
311
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
312
+
313
+ text_inputs = self.tokenizer_2(
314
+ prompt,
315
+ padding="max_length",
316
+ max_length=max_sequence_length,
317
+ truncation=True,
318
+ return_length=False,
319
+ return_overflowing_tokens=False,
320
+ return_tensors="pt",
321
+ )
322
+ text_input_ids = text_inputs.input_ids
323
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
324
+
325
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
326
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
327
+ logger.warning(
328
+ "The following part of your input was truncated because `max_sequence_length` is set to "
329
+ f" {max_sequence_length} tokens: {removed_text}"
330
+ )
331
+
332
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
333
+
334
+ dtype = self.text_encoder_2.dtype
335
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
336
+
337
+ _, seq_len, _ = prompt_embeds.shape
338
+
339
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
340
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
341
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
342
+
343
+ return prompt_embeds
344
+
345
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
346
+ def _get_clip_prompt_embeds(
347
+ self,
348
+ prompt: Union[str, List[str]],
349
+ num_images_per_prompt: int = 1,
350
+ device: Optional[torch.device] = None,
351
+ ):
352
+ device = device or self._execution_device
353
+
354
+ prompt = [prompt] if isinstance(prompt, str) else prompt
355
+ batch_size = len(prompt)
356
+
357
+ if isinstance(self, TextualInversionLoaderMixin):
358
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
359
+
360
+ text_inputs = self.tokenizer(
361
+ prompt,
362
+ padding="max_length",
363
+ max_length=self.tokenizer_max_length,
364
+ truncation=True,
365
+ return_overflowing_tokens=False,
366
+ return_length=False,
367
+ return_tensors="pt",
368
+ )
369
+
370
+ text_input_ids = text_inputs.input_ids
371
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
372
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
373
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
374
+ logger.warning(
375
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
376
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
377
+ )
378
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
379
+
380
+ # Use pooled output of CLIPTextModel
381
+ prompt_embeds = prompt_embeds.pooler_output
382
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
383
+
384
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
385
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
386
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
387
+
388
+ return prompt_embeds
389
+
390
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
391
+ def encode_prompt(
392
+ self,
393
+ prompt: Union[str, List[str]],
394
+ prompt_2: Optional[Union[str, List[str]]] = None,
395
+ device: Optional[torch.device] = None,
396
+ num_images_per_prompt: int = 1,
397
+ prompt_embeds: Optional[torch.FloatTensor] = None,
398
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
399
+ max_sequence_length: int = 512,
400
+ lora_scale: Optional[float] = None,
401
+ ):
402
+ r"""
403
+
404
+ Args:
405
+ prompt (`str` or `List[str]`, *optional*):
406
+ prompt to be encoded
407
+ prompt_2 (`str` or `List[str]`, *optional*):
408
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
409
+ used in all text-encoders
410
+ device: (`torch.device`):
411
+ torch device
412
+ num_images_per_prompt (`int`):
413
+ number of images that should be generated per prompt
414
+ prompt_embeds (`torch.FloatTensor`, *optional*):
415
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
416
+ provided, text embeddings will be generated from `prompt` input argument.
417
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
418
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
419
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
420
+ lora_scale (`float`, *optional*):
421
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
422
+ """
423
+ device = device or self._execution_device
424
+
425
+ # set lora scale so that monkey patched LoRA
426
+ # function of text encoder can correctly access it
427
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
428
+ self._lora_scale = lora_scale
429
+
430
+ # dynamically adjust the LoRA scale
431
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
432
+ scale_lora_layers(self.text_encoder, lora_scale)
433
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
434
+ scale_lora_layers(self.text_encoder_2, lora_scale)
435
+
436
+ prompt = [prompt] if isinstance(prompt, str) else prompt
437
+
438
+ if prompt_embeds is None:
439
+ prompt_2 = prompt_2 or prompt
440
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
441
+
442
+ # We only use the pooled prompt output from the CLIPTextModel
443
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
444
+ prompt=prompt,
445
+ device=device,
446
+ num_images_per_prompt=num_images_per_prompt,
447
+ )
448
+ prompt_embeds = self._get_t5_prompt_embeds(
449
+ prompt=prompt_2,
450
+ num_images_per_prompt=num_images_per_prompt,
451
+ max_sequence_length=max_sequence_length,
452
+ device=device,
453
+ )
454
+
455
+ if self.text_encoder is not None:
456
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
457
+ # Retrieve the original scale by scaling back the LoRA layers
458
+ unscale_lora_layers(self.text_encoder, lora_scale)
459
+
460
+ if self.text_encoder_2 is not None:
461
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
462
+ # Retrieve the original scale by scaling back the LoRA layers
463
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
464
+
465
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
466
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
467
+
468
+ return prompt_embeds, pooled_prompt_embeds, text_ids
469
+
470
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
471
+ def encode_image(self, image, device, num_images_per_prompt):
472
+ dtype = next(self.image_encoder.parameters()).dtype
473
+
474
+ if not isinstance(image, torch.Tensor):
475
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
476
+
477
+ image = image.to(device=device, dtype=dtype)
478
+ image_embeds = self.image_encoder(image).image_embeds
479
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
480
+ return image_embeds
481
+
482
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
483
+ def prepare_ip_adapter_image_embeds(
484
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
485
+ ):
486
+ image_embeds = []
487
+ if ip_adapter_image_embeds is None:
488
+ if not isinstance(ip_adapter_image, list):
489
+ ip_adapter_image = [ip_adapter_image]
490
+
491
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
492
+ raise ValueError(
493
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
494
+ )
495
+
496
+ for single_ip_adapter_image in ip_adapter_image:
497
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
498
+ image_embeds.append(single_image_embeds[None, :])
499
+ else:
500
+ if not isinstance(ip_adapter_image_embeds, list):
501
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
502
+
503
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
504
+ raise ValueError(
505
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
506
+ )
507
+
508
+ for single_image_embeds in ip_adapter_image_embeds:
509
+ image_embeds.append(single_image_embeds)
510
+
511
+ ip_adapter_image_embeds = []
512
+ for single_image_embeds in image_embeds:
513
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
514
+ single_image_embeds = single_image_embeds.to(device=device)
515
+ ip_adapter_image_embeds.append(single_image_embeds)
516
+
517
+ return ip_adapter_image_embeds
518
+
519
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
520
+ def get_timesteps(self, num_inference_steps, strength, device):
521
+ # get the original timestep using init_timestep
522
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
523
+
524
+ t_start = int(max(num_inference_steps - init_timestep, 0))
525
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
526
+ if hasattr(self.scheduler, "set_begin_index"):
527
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
528
+
529
+ return timesteps, num_inference_steps - t_start
530
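A small worked example of how `strength` truncates the schedule in `get_timesteps`, using assumed values:

num_inference_steps, strength = 28, 0.6                                    # assumed values
init_timestep = min(num_inference_steps * strength, num_inference_steps)   # 16.8
t_start = int(max(num_inference_steps - init_timestep, 0))                 # 11
# scheduler.timesteps[t_start * order:] is used, so with order == 1 only the
# last 28 - 11 == 17 steps run; strength == 1.0 keeps the full schedule.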
+
531
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs
532
+ def check_inputs(
533
+ self,
534
+ prompt,
535
+ prompt_2,
536
+ image,
537
+ mask_image,
538
+ strength,
539
+ height,
540
+ width,
541
+ output_type,
542
+ negative_prompt=None,
543
+ negative_prompt_2=None,
544
+ prompt_embeds=None,
545
+ negative_prompt_embeds=None,
546
+ pooled_prompt_embeds=None,
547
+ negative_pooled_prompt_embeds=None,
548
+ callback_on_step_end_tensor_inputs=None,
549
+ padding_mask_crop=None,
550
+ max_sequence_length=None,
551
+ ):
552
+ if strength < 0 or strength > 1:
553
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
554
+
555
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
556
+ logger.warning(
557
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
558
+ )
559
+
560
+ if callback_on_step_end_tensor_inputs is not None and not all(
561
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
562
+ ):
563
+ raise ValueError(
564
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
565
+ )
566
+
567
+ if prompt is not None and prompt_embeds is not None:
568
+ raise ValueError(
569
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
570
+ " only forward one of the two."
571
+ )
572
+ elif prompt_2 is not None and prompt_embeds is not None:
573
+ raise ValueError(
574
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
575
+ " only forward one of the two."
576
+ )
577
+ elif prompt is None and prompt_embeds is None:
578
+ raise ValueError(
579
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
580
+ )
581
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
582
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
583
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
584
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
585
+
586
+ if negative_prompt is not None and negative_prompt_embeds is not None:
587
+ raise ValueError(
588
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
589
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
590
+ )
591
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
592
+ raise ValueError(
593
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
594
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
595
+ )
596
+
597
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
598
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
599
+ raise ValueError(
600
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
601
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
602
+ f" {negative_prompt_embeds.shape}."
603
+ )
604
+
605
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
606
+ raise ValueError(
607
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
608
+ )
609
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
610
+ raise ValueError(
611
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
612
+ )
613
+
614
+ if padding_mask_crop is not None:
615
+ if not isinstance(image, PIL.Image.Image):
616
+ raise ValueError(
617
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
618
+ )
619
+ if not isinstance(mask_image, PIL.Image.Image):
620
+ raise ValueError(
621
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
622
+ f" {type(mask_image)}."
623
+ )
624
+ if output_type != "pil":
625
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
626
+
627
+ if max_sequence_length is not None and max_sequence_length > 512:
628
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
629
+
630
+ @staticmethod
631
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
632
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
633
+ latent_image_ids = torch.zeros(height, width, 3)
634
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
635
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
636
+
637
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
638
+
639
+ latent_image_ids = latent_image_ids.reshape(
640
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
641
+ )
642
+
643
+ return latent_image_ids.to(device=device, dtype=dtype)
644
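A minimal sketch of the id grid this helper builds, shown for an assumed 4x4 packed latent:

import torch

ids = torch.zeros(4, 4, 3)
ids[..., 1] += torch.arange(4)[:, None]   # row index
ids[..., 2] += torch.arange(4)[None, :]   # column index
ids = ids.reshape(16, 3)                  # one (modality flag, row, col) triple per token
# Column 0 stays 0 for the noise latents; the conditioning image ids built later set it to 1.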
+
645
+ @staticmethod
646
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
647
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
648
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
649
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
650
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
651
+
652
+ return latents
653
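A shape walk-through of the 2x2 packing above, assuming a 1024x1024 image, a `vae_scale_factor` of 8 and 16 latent channels:

import torch

latents = torch.randn(1, 16, 128, 128)                         # (B, C, H/8, W/8)
packed = latents.view(1, 16, 64, 2, 64, 2).permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(1, 64 * 64, 16 * 4)                    # (B, 4096 tokens, 64 channels)
# `_unpack_latents` below reverses this to recover (1, 16, 128, 128).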
+
654
+ @staticmethod
655
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
656
+ def _unpack_latents(latents, height, width, vae_scale_factor):
657
+ batch_size, num_patches, channels = latents.shape
658
+
659
+ # VAE applies 8x compression on images but we must also account for packing which requires
660
+ # latent height and width to be divisible by 2.
661
+ height = 2 * (int(height) // (vae_scale_factor * 2))
662
+ width = 2 * (int(width) // (vae_scale_factor * 2))
663
+
664
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
665
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
666
+
667
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
668
+
669
+ return latents
670
+
671
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
672
+ if isinstance(generator, list):
673
+ image_latents = [
674
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
675
+ for i in range(image.shape[0])
676
+ ]
677
+ image_latents = torch.cat(image_latents, dim=0)
678
+ else:
679
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
680
+
681
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
682
+
683
+ return image_latents
684
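The normalisation applied in `_encode_vae_image` shown in isolation; the shift and scaling factors below are assumed, typical Flux-style VAE config values:

import torch

raw_latents = torch.randn(1, 16, 128, 128)
shift_factor, scaling_factor = 0.1159, 0.3611                  # assumed config values
normalised = (raw_latents - shift_factor) * scaling_factor     # as in _encode_vae_image
restored = normalised / scaling_factor + shift_factor          # inverse, applied before VAE decode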
+
685
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
686
+ def enable_vae_slicing(self):
687
+ r"""
688
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
689
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
690
+ """
691
+ self.vae.enable_slicing()
692
+
693
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
694
+ def disable_vae_slicing(self):
695
+ r"""
696
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
697
+ computing decoding in one step.
698
+ """
699
+ self.vae.disable_slicing()
700
+
701
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
702
+ def enable_vae_tiling(self):
703
+ r"""
704
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
705
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
706
+ processing larger images.
707
+ """
708
+ self.vae.enable_tiling()
709
+
710
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
711
+ def disable_vae_tiling(self):
712
+ r"""
713
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
714
+ computing decoding in one step.
715
+ """
716
+ self.vae.disable_tiling()
717
+
718
+ def prepare_latents(
719
+ self,
720
+ image: Optional[torch.Tensor],
721
+ timestep: int,
722
+ batch_size: int,
723
+ num_channels_latents: int,
724
+ height: int,
725
+ width: int,
726
+ dtype: torch.dtype,
727
+ device: torch.device,
728
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
729
+ latents: Optional[torch.Tensor] = None,
730
+ image_reference: Optional[torch.Tensor] = None,
731
+ ):
732
+ if isinstance(generator, list) and len(generator) != batch_size:
733
+ raise ValueError(
734
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
735
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
736
+ )
737
+
738
+ # VAE applies 8x compression on images but we must also account for packing which requires
739
+ # latent height and width to be divisible by 2.
740
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
741
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
742
+ shape = (batch_size, num_channels_latents, height, width)
743
+
744
+ # Prepare image latents
745
+ image_latents = image_ids = None
746
+ if image is not None:
747
+ image = image.to(device=device, dtype=dtype)
748
+ if image.shape[1] != self.latent_channels:
749
+ image_latents = self._encode_vae_image(image=image, generator=generator)
750
+ else:
751
+ image_latents = image
752
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
753
+ # expand init_latents for batch_size
754
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
755
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
756
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
757
+ raise ValueError(
758
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
759
+ )
760
+ else:
761
+ image_latents = torch.cat([image_latents], dim=0)
762
+
763
+ # Prepare image reference latents
764
+ image_reference_latents = image_reference_ids = None
765
+ if image_reference is not None:
766
+ image_reference = image_reference.to(device=device, dtype=dtype)
767
+ if image_reference.shape[1] != self.latent_channels:
768
+ image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator)
769
+ else:
770
+ image_reference_latents = image_reference
771
+ if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0:
772
+ # expand init_latents for batch_size
773
+ additional_image_per_prompt = batch_size // image_reference_latents.shape[0]
774
+ image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0)
775
+ elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0:
776
+ raise ValueError(
777
+ f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts."
778
+ )
779
+ else:
780
+ image_reference_latents = torch.cat([image_reference_latents], dim=0)
781
+
782
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
783
+
784
+ if latents is None:
785
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
786
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
787
+ else:
788
+ noise = latents.to(device=device, dtype=dtype)
789
+ latents = noise
790
+
791
+ image_latent_height, image_latent_width = image_latents.shape[2:]
792
+ image_latents = self._pack_latents(
793
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
794
+ )
795
+ image_ids = self._prepare_latent_image_ids(
796
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
797
+ )
798
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
799
+ image_ids[..., 0] = 1
800
+
801
+ if image_reference_latents is not None:
802
+ image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
803
+ image_reference_latents = self._pack_latents(
804
+ image_reference_latents,
805
+ batch_size,
806
+ num_channels_latents,
807
+ image_reference_latent_height,
808
+ image_reference_latent_width,
809
+ )
810
+ image_reference_ids = self._prepare_latent_image_ids(
811
+ batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
812
+ )
813
+ # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
814
+ image_reference_ids[..., 0] = 1
815
+
816
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
817
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
818
+
819
+ return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise
820
+
821
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
822
+ def prepare_mask_latents(
823
+ self,
824
+ mask,
825
+ masked_image,
826
+ batch_size,
827
+ num_channels_latents,
828
+ num_images_per_prompt,
829
+ height,
830
+ width,
831
+ dtype,
832
+ device,
833
+ generator,
834
+ ):
835
+ # VAE applies 8x compression on images but we must also account for packing which requires
836
+ # latent height and width to be divisible by 2.
837
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
838
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
839
+ # resize the mask to latents shape as we concatenate the mask to the latents
840
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
841
+ # and half precision
842
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
843
+ mask = mask.to(device=device, dtype=dtype)
844
+
845
+ batch_size = batch_size * num_images_per_prompt
846
+
847
+ masked_image = masked_image.to(device=device, dtype=dtype)
848
+
849
+ if masked_image.shape[1] == 16:
850
+ masked_image_latents = masked_image
851
+ else:
852
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
853
+
854
+ masked_image_latents = (
855
+ masked_image_latents - self.vae.config.shift_factor
856
+ ) * self.vae.config.scaling_factor
857
+
858
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
859
+ if mask.shape[0] < batch_size:
860
+ if not batch_size % mask.shape[0] == 0:
861
+ raise ValueError(
862
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
863
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
864
+ " of masks that you pass is divisible by the total requested batch size."
865
+ )
866
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
867
+ if masked_image_latents.shape[0] < batch_size:
868
+ if not batch_size % masked_image_latents.shape[0] == 0:
869
+ raise ValueError(
870
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
871
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
872
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
873
+ )
874
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
875
+
876
+ # aligning device to prevent device errors when concating it with the latent model input
877
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
878
+ masked_image_latents = self._pack_latents(
879
+ masked_image_latents,
880
+ batch_size,
881
+ num_channels_latents,
882
+ height,
883
+ width,
884
+ )
885
+ mask = self._pack_latents(
886
+ mask.repeat(1, num_channels_latents, 1, 1),
887
+ batch_size,
888
+ num_channels_latents,
889
+ height,
890
+ width,
891
+ )
892
+
893
+ return mask, masked_image_latents
894
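A shape sketch of the mask path above, assuming a 1024x1024 input, a `vae_scale_factor` of 8 and 16 latent channels:

import torch

mask = torch.ones(1, 1, 1024, 1024)
mask = torch.nn.functional.interpolate(mask, size=(128, 128))  # latent resolution
mask = mask.repeat(1, 16, 1, 1)                                # broadcast across latent channels
# after _pack_latents: (1, 64 * 64, 64), i.e. one mask value per packed token and channel,
# matching the packed image latents it gates during denoising.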
+
895
+ @property
896
+ def guidance_scale(self):
897
+ return self._guidance_scale
898
+
899
+ @property
900
+ def joint_attention_kwargs(self):
901
+ return self._joint_attention_kwargs
902
+
903
+ @property
904
+ def num_timesteps(self):
905
+ return self._num_timesteps
906
+
907
+ @property
908
+ def current_timestep(self):
909
+ return self._current_timestep
910
+
911
+ @property
912
+ def interrupt(self):
913
+ return self._interrupt
914
+
915
+ @torch.no_grad()
916
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
917
+ def __call__(
918
+ self,
919
+ image: Optional[PipelineImageInput] = None,
920
+ image_reference: Optional[PipelineImageInput] = None,
921
+ mask_image: PipelineImageInput = None,
922
+ prompt: Union[str, List[str]] = None,
923
+ prompt_2: Optional[Union[str, List[str]]] = None,
924
+ negative_prompt: Union[str, List[str]] = None,
925
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
926
+ true_cfg_scale: float = 1.0,
927
+ height: Optional[int] = None,
928
+ width: Optional[int] = None,
929
+ strength: float = 1.0,
930
+ padding_mask_crop: Optional[int] = None,
931
+ num_inference_steps: int = 28,
932
+ sigmas: Optional[List[float]] = None,
933
+ guidance_scale: float = 3.5,
934
+ num_images_per_prompt: Optional[int] = 1,
935
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
936
+ latents: Optional[torch.FloatTensor] = None,
937
+ prompt_embeds: Optional[torch.FloatTensor] = None,
938
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
939
+ ip_adapter_image: Optional[PipelineImageInput] = None,
940
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
941
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
942
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
943
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
944
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
945
+ output_type: Optional[str] = "pil",
946
+ return_dict: bool = True,
947
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
948
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
949
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
950
+ max_sequence_length: int = 512,
951
+ max_area: int = 1024**2,
952
+ _auto_resize: bool = True,
953
+ ):
954
+ r"""
955
+ Function invoked when calling the pipeline for generation.
956
+
957
+ Args:
958
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
959
+ `Image`, numpy array or tensor representing an image batch to be inpainted (the parts of the image
960
+ to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both
961
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
962
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
963
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
964
+ latents as `image`, but if passing latents directly it is not encoded again.
965
+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
966
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
967
+ masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If
968
+ it's a tensor or a list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is
969
+ a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can
970
+ also accept image latents as `image_reference`, but if passing latents directly it is not encoded again.
971
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
972
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
973
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
974
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
975
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
976
+ H, W)`, `(1, H, W)`, or `(H, W)`. For a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
977
+ 1)`, or `(H, W)`.
978
+ prompt (`str` or `List[str]`, *optional*):
979
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
980
+ instead.
981
+ prompt_2 (`str` or `List[str]`, *optional*):
982
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
983
+ will be used instead.
984
+ negative_prompt (`str` or `List[str]`, *optional*):
985
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
986
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
987
+ not greater than `1`).
988
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
989
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
990
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
991
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
992
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
993
+ `negative_prompt` is provided.
994
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
995
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
996
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
997
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
998
+ strength (`float`, *optional*, defaults to 1.0):
999
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
1000
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
1001
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
1002
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
1003
+ essentially ignores `image`.
1004
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1005
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1006
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1007
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that area based
1008
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1009
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1010
+ the image is large and contains information irrelevant to inpainting, such as the background.
1011
+ num_inference_steps (`int`, *optional*, defaults to 28):
1012
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1013
+ expense of slower inference.
1014
+ sigmas (`List[float]`, *optional*):
1015
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1016
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1017
+ will be used.
1018
+ guidance_scale (`float`, *optional*, defaults to 3.5):
1019
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
1020
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
1021
+
1022
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
1023
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
1024
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1025
+ The number of images to generate per prompt.
1026
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1027
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1028
+ to make generation deterministic.
1029
+ latents (`torch.FloatTensor`, *optional*):
1030
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1031
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1032
+ tensor will be generated by sampling using the supplied random `generator`.
1033
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1034
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1035
+ provided, text embeddings will be generated from `prompt` input argument.
1036
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1037
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1038
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1039
+ ip_adapter_image (`PipelineImageInput`, *optional*):
1040
+ Optional image input to work with IP Adapters.
1041
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1042
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
1043
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
1044
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1045
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
1046
+ Optional image input to work with IP Adapters.
1047
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1048
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
1049
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
1050
+ provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
1051
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1052
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1053
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1054
+ argument.
1055
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1056
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1057
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1058
+ input argument.
1059
+ output_type (`str`, *optional*, defaults to `"pil"`):
1060
+ The output format of the generated image. Choose between
1061
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1062
+ return_dict (`bool`, *optional*, defaults to `True`):
1063
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
1064
+ joint_attention_kwargs (`dict`, *optional*):
1065
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1066
+ `self.processor` in
1067
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1068
+ callback_on_step_end (`Callable`, *optional*):
1069
+ A function that is called at the end of each denoising step during inference. The function is called
1070
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1071
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1072
+ `callback_on_step_end_tensor_inputs`.
1073
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1074
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1075
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1076
+ `._callback_tensor_inputs` attribute of your pipeline class.
1077
+ max_sequence_length (`int`, *optional*, defaults to 512):
1078
+ Maximum sequence length to use with the `prompt`.
1079
+ max_area (`int`, defaults to `1024 ** 2`):
1080
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
1081
+ area while maintaining the aspect ratio.
1082
+
1083
+ Examples:
1084
+
1085
+ Returns:
1086
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
1087
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
1088
+ images.
1089
+ """
1090
+
1091
+ height = height or self.default_sample_size * self.vae_scale_factor
1092
+ width = width or self.default_sample_size * self.vae_scale_factor
1093
+
1094
+ original_height, original_width = height, width
1095
+ aspect_ratio = width / height
1096
+ width = round((max_area * aspect_ratio) ** 0.5)
1097
+ height = round((max_area / aspect_ratio) ** 0.5)
1098
+
1099
+ multiple_of = self.vae_scale_factor * 2
1100
+ width = width // multiple_of * multiple_of
1101
+ height = height // multiple_of * multiple_of
1102
+
1103
+ if height != original_height or width != original_width:
1104
+ logger.warning(
1105
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
1106
+ )
1107
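A worked example of the adjustment above, assuming a 1920x1080 request, `max_area = 1024**2` and a `vae_scale_factor` of 8:

max_area, vae_scale_factor = 1024**2, 8            # assumed values
aspect_ratio = 1920 / 1080                         # ~1.778
width = round((max_area * aspect_ratio) ** 0.5)    # 1365
height = round((max_area / aspect_ratio) ** 0.5)   # 768
multiple_of = vae_scale_factor * 2
width = width // multiple_of * multiple_of         # 1360
height = height // multiple_of * multiple_of       # 768 -> generation runs at 1360x768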
+
1108
+ # 1. Check inputs. Raise error if not correct
1109
+ self.check_inputs(
1110
+ prompt,
1111
+ prompt_2,
1112
+ image,
1113
+ mask_image,
1114
+ strength,
1115
+ height,
1116
+ width,
1117
+ output_type=output_type,
1118
+ negative_prompt=negative_prompt,
1119
+ negative_prompt_2=negative_prompt_2,
1120
+ prompt_embeds=prompt_embeds,
1121
+ negative_prompt_embeds=negative_prompt_embeds,
1122
+ pooled_prompt_embeds=pooled_prompt_embeds,
1123
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1124
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1125
+ padding_mask_crop=padding_mask_crop,
1126
+ max_sequence_length=max_sequence_length,
1127
+ )
1128
+
1129
+ self._guidance_scale = guidance_scale
1130
+ self._joint_attention_kwargs = joint_attention_kwargs
1131
+ self._current_timestep = None
1132
+ self._interrupt = False
1133
+
1134
+ # 2. Preprocess image
1135
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
1136
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
1137
+ image = torch.cat(image, dim=0)
1138
+ img = image[0] if isinstance(image, list) else image
1139
+ image_height, image_width = self.image_processor.get_default_height_width(img)
1140
+ aspect_ratio = image_width / image_height
1141
+ if _auto_resize:
1142
+ # Kontext is trained on specific resolutions, using one of them is recommended
1143
+ _, image_width, image_height = min(
1144
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
1145
+ )
1146
+ image_width = image_width // multiple_of * multiple_of
1147
+ image_height = image_height // multiple_of * multiple_of
1148
+ image = self.image_processor.resize(image, image_height, image_width)
1149
+
1150
+ # Use the (resized) input image resolution for generation
1151
+ width = image_width
1152
+ height = image_height
1153
+
1154
+ # 2.1 Preprocess mask
1155
+ if padding_mask_crop is not None:
1156
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1157
+ resize_mode = "fill"
1158
+ else:
1159
+ crops_coords = None
1160
+ resize_mode = "default"
1161
+
1162
+ image = self.image_processor.preprocess(
1163
+ image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
1164
+ )
1165
+ else:
1166
+ raise ValueError("image must be provided correctly for inpainting")
1167
+
1168
+ init_image = image.to(dtype=torch.float32)
1169
+
1170
+ # 2.1 Preprocess image_reference
1171
+ if image_reference is not None and not (
1172
+ isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
1173
+ ):
1174
+ if (
1175
+ isinstance(image_reference, list)
1176
+ and isinstance(image_reference[0], torch.Tensor)
1177
+ and image_reference[0].ndim == 4
1178
+ ):
1179
+ image_reference = torch.cat(image_reference, dim=0)
1180
+ img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
1181
+ image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
1182
+ img_reference
1183
+ )
1184
+ aspect_ratio = image_reference_width / image_reference_height
1185
+ if _auto_resize:
1186
+ # Kontext is trained on specific resolutions, using one of them is recommended
1187
+ _, image_reference_width, image_reference_height = min(
1188
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
1189
+ )
1190
+ image_reference_width = image_reference_width // multiple_of * multiple_of
1191
+ image_reference_height = image_reference_height // multiple_of * multiple_of
1192
+ image_reference = self.image_processor.resize(
1193
+ image_reference, image_reference_height, image_reference_width
1194
+ )
1195
+ image_reference = self.image_processor.preprocess(
1196
+ image_reference,
1197
+ image_reference_height,
1198
+ image_reference_width,
1199
+ crops_coords=crops_coords,
1200
+ resize_mode=resize_mode,
1201
+ )
1202
+ else:
1203
+ image_reference = None
1204
+
1205
+ # 3. Define call parameters
1206
+ if prompt is not None and isinstance(prompt, str):
1207
+ batch_size = 1
1208
+ elif prompt is not None and isinstance(prompt, list):
1209
+ batch_size = len(prompt)
1210
+ else:
1211
+ batch_size = prompt_embeds.shape[0]
1212
+
1213
+ device = self._execution_device
1214
+
1215
+ lora_scale = (
1216
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1217
+ )
1218
+ has_neg_prompt = negative_prompt is not None or (
1219
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
1220
+ )
1221
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
1222
+ (
1223
+ prompt_embeds,
1224
+ pooled_prompt_embeds,
1225
+ text_ids,
1226
+ ) = self.encode_prompt(
1227
+ prompt=prompt,
1228
+ prompt_2=prompt_2,
1229
+ prompt_embeds=prompt_embeds,
1230
+ pooled_prompt_embeds=pooled_prompt_embeds,
1231
+ device=device,
1232
+ num_images_per_prompt=num_images_per_prompt,
1233
+ max_sequence_length=max_sequence_length,
1234
+ lora_scale=lora_scale,
1235
+ )
1236
+ if do_true_cfg:
1237
+ (
1238
+ negative_prompt_embeds,
1239
+ negative_pooled_prompt_embeds,
1240
+ negative_text_ids,
1241
+ ) = self.encode_prompt(
1242
+ prompt=negative_prompt,
1243
+ prompt_2=negative_prompt_2,
1244
+ prompt_embeds=negative_prompt_embeds,
1245
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
1246
+ device=device,
1247
+ num_images_per_prompt=num_images_per_prompt,
1248
+ max_sequence_length=max_sequence_length,
1249
+ lora_scale=lora_scale,
1250
+ )
1251
+
1252
+ # 4. Prepare timesteps
1253
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1254
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
1255
+ mu = calculate_shift(
1256
+ image_seq_len,
1257
+ self.scheduler.config.get("base_image_seq_len", 256),
1258
+ self.scheduler.config.get("max_image_seq_len", 4096),
1259
+ self.scheduler.config.get("base_shift", 0.5),
1260
+ self.scheduler.config.get("max_shift", 1.15),
1261
+ )
1262
+ timesteps, num_inference_steps = retrieve_timesteps(
1263
+ self.scheduler,
1264
+ num_inference_steps,
1265
+ device,
1266
+ sigmas=sigmas,
1267
+ mu=mu,
1268
+ )
1269
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1270
+ if num_inference_steps < 1:
1271
+ raise ValueError(
1272
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1273
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1274
+ )
1275
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1276
+
1277
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1278
+ self._num_timesteps = len(timesteps)
1279
+
1280
+ # 5. Prepare latent variables
1281
+ num_channels_latents = self.transformer.config.in_channels // 4
1282
+ latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
1283
+ self.prepare_latents(
1284
+ init_image,
1285
+ latent_timestep,
1286
+ batch_size * num_images_per_prompt,
1287
+ num_channels_latents,
1288
+ height,
1289
+ width,
1290
+ prompt_embeds.dtype,
1291
+ device,
1292
+ generator,
1293
+ latents,
1294
+ image_reference,
1295
+ )
1296
+ )
1297
+
1298
+ if image_reference_ids is not None:
1299
+ latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension
1300
+ elif image_ids is not None:
1301
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
1302
+
1303
+ mask_condition = self.mask_processor.preprocess(
1304
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1305
+ )
1306
+
1307
+ masked_image = init_image * (mask_condition < 0.5)
1308
+
1309
+ mask, _ = self.prepare_mask_latents(
1310
+ mask_condition,
1311
+ masked_image,
1312
+ batch_size,
1313
+ num_channels_latents,
1314
+ num_images_per_prompt,
1315
+ height,
1316
+ width,
1317
+ prompt_embeds.dtype,
1318
+ device,
1319
+ generator,
1320
+ )
1321
+
1322
+ # handle guidance
1323
+ if self.transformer.config.guidance_embeds:
1324
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1325
+ guidance = guidance.expand(latents.shape[0])
1326
+ else:
1327
+ guidance = None
1328
+
1329
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1330
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1331
+ ):
1332
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1333
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1334
+
1335
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1336
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1337
+ ):
1338
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1339
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
1340
+
1341
+ if self.joint_attention_kwargs is None:
1342
+ self._joint_attention_kwargs = {}
1343
+
1344
+ image_embeds = None
1345
+ negative_image_embeds = None
1346
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1347
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1348
+ ip_adapter_image,
1349
+ ip_adapter_image_embeds,
1350
+ device,
1351
+ batch_size * num_images_per_prompt,
1352
+ )
1353
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1354
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1355
+ negative_ip_adapter_image,
1356
+ negative_ip_adapter_image_embeds,
1357
+ device,
1358
+ batch_size * num_images_per_prompt,
1359
+ )
1360
+
1361
+ # 6. Denoising loop
1362
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1363
+ for i, t in enumerate(timesteps):
1364
+ if self.interrupt:
1365
+ continue
1366
+
1367
+ self._current_timestep = t
1368
+ if image_embeds is not None:
1369
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1370
+
1371
+ latent_model_input = latents
1372
+ if image_reference_latents is not None:
1373
+ latent_model_input = torch.cat([latents, image_reference_latents], dim=1)
1374
+ elif image_latents is not None:
1375
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
1376
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1377
+
1378
+ noise_pred = self.transformer(
1379
+ hidden_states=latent_model_input,
1380
+ timestep=timestep / 1000,
1381
+ guidance=guidance,
1382
+ pooled_projections=pooled_prompt_embeds,
1383
+ encoder_hidden_states=prompt_embeds,
1384
+ txt_ids=text_ids,
1385
+ img_ids=latent_ids,
1386
+ joint_attention_kwargs=self.joint_attention_kwargs,
1387
+ return_dict=False,
1388
+ )[0]
1389
+ noise_pred = noise_pred[:, : latents.size(1)]
1390
+
1391
+ if do_true_cfg:
1392
+ if negative_image_embeds is not None:
1393
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1394
+ neg_noise_pred = self.transformer(
1395
+ hidden_states=latent_model_input,
1396
+ timestep=timestep / 1000,
1397
+ guidance=guidance,
1398
+ pooled_projections=negative_pooled_prompt_embeds,
1399
+ encoder_hidden_states=negative_prompt_embeds,
1400
+ txt_ids=negative_text_ids,
1401
+ img_ids=latent_ids,
1402
+ joint_attention_kwargs=self.joint_attention_kwargs,
1403
+ return_dict=False,
1404
+ )[0]
1405
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
1406
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1407
+
1408
+ # compute the previous noisy sample x_t -> x_t-1
1409
+ latents_dtype = latents.dtype
1410
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1411
+
1412
+ init_latents_proper = image_latents
1413
+ init_mask = mask
1414
+
1415
+ if i < len(timesteps) - 1:
1416
+ noise_timestep = timesteps[i + 1]
1417
+ init_latents_proper = self.scheduler.scale_noise(
1418
+ init_latents_proper, torch.tensor([noise_timestep]), noise
1419
+ )
1420
+
1421
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1422
+
1423
+ if latents.dtype != latents_dtype:
1424
+ if torch.backends.mps.is_available():
1425
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1426
+ latents = latents.to(latents_dtype)
1427
+
1428
+ if callback_on_step_end is not None:
1429
+ callback_kwargs = {}
1430
+ for k in callback_on_step_end_tensor_inputs:
1431
+ callback_kwargs[k] = locals()[k]
1432
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1433
+
1434
+ latents = callback_outputs.pop("latents", latents)
1435
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1436
+
1437
+ # call the callback, if provided
1438
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1439
+ progress_bar.update()
1440
+
1441
+ if XLA_AVAILABLE:
1442
+ xm.mark_step()
1443
+
1444
+ self._current_timestep = None
1445
+
1446
+ if output_type == "latent":
1447
+ image = latents
1448
+ else:
1449
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1450
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1451
+ image = self.vae.decode(latents, return_dict=False)[0]
1452
+ image = self.image_processor.postprocess(image, output_type=output_type)
1453
+
1454
+ # Offload all models
1455
+ self.maybe_free_model_hooks()
1456
+
1457
+ if not return_dict:
1458
+ return (image,)
1459
+
1460
+ return FluxPipelineOutput(images=image)
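For orientation, a minimal end-to-end usage sketch of the pipeline added in this hunk. The export name and checkpoint id below are assumptions (they are not visible in this part of the diff); the call mirrors the `__call__` signature documented above.

import torch
from diffusers import FluxKontextInpaintPipeline   # assumed export name
from diffusers.utils import load_image

pipe = FluxKontextInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",         # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("input.png")          # image to inpaint
mask_image = load_image("mask.png")      # white = repaint, black = keep
image_reference = load_image("ref.png")  # optional reference for the masked region

result = pipe(
    image=image,
    mask_image=mask_image,
    image_reference=image_reference,
    prompt="replace the masked region with a red vintage car",
    strength=1.0,
    guidance_scale=3.5,
    num_inference_steps=28,
).images[0]
result.save("output.png")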