diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/omnigen/processor_omnigen.py (new file)
@@ -0,0 +1,327 @@
+ # Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from typing import Dict, List
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+
+ def crop_image(pil_image, max_image_size):
+     """
+     Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and
+     width are multiples of 16.
+     """
+     while min(*pil_image.size) >= 2 * max_image_size:
+         pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+     if max(*pil_image.size) > max_image_size:
+         scale = max_image_size / max(*pil_image.size)
+         pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+     if min(*pil_image.size) < 16:
+         scale = 16 / min(*pil_image.size)
+         pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+     arr = np.array(pil_image)
+     crop_y1 = (arr.shape[0] % 16) // 2
+     crop_y2 = arr.shape[0] % 16 - crop_y1
+
+     crop_x1 = (arr.shape[1] % 16) // 2
+     crop_x2 = arr.shape[1] % 16 - crop_x1
+
+     arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
+     return Image.fromarray(arr)
+
+
+ class OmniGenMultiModalProcessor:
+     def __init__(self, text_tokenizer, max_image_size: int = 1024):
+         self.text_tokenizer = text_tokenizer
+         self.max_image_size = max_image_size
+
+         self.image_transform = transforms.Compose(
+             [
+                 transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+             ]
+         )
+
+         self.collator = OmniGenCollator()
+
+     def reset_max_image_size(self, max_image_size):
+         self.max_image_size = max_image_size
+         self.image_transform = transforms.Compose(
+             [
+                 transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+             ]
+         )
+
+     def process_image(self, image):
+         if isinstance(image, str):
+             image = Image.open(image).convert("RGB")
+         return self.image_transform(image)
+
+     def process_multi_modal_prompt(self, text, input_images):
+         text = self.add_prefix_instruction(text)
+         if input_images is None or len(input_images) == 0:
+             model_inputs = self.text_tokenizer(text)
+             return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+         pattern = r"<\|image_\d+\|>"
+         prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+         for i in range(1, len(prompt_chunks)):
+             if prompt_chunks[i][0] == 1:
+                 prompt_chunks[i] = prompt_chunks[i][1:]
+
+         image_tags = re.findall(pattern, text)
+         image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+         unique_image_ids = sorted(set(image_ids))
+         assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
+             f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+         )
+         # total images must be the same as the number of image tags
+         assert len(unique_image_ids) == len(input_images), (
+             f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+         )
+
+         input_images = [input_images[x - 1] for x in image_ids]
+
+         all_input_ids = []
+         img_inx = []
+         for i in range(len(prompt_chunks)):
+             all_input_ids.extend(prompt_chunks[i])
+             if i != len(prompt_chunks) - 1:
+                 start_inx = len(all_input_ids)
+                 size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+                 img_inx.append([start_inx, start_inx + size])
+                 all_input_ids.extend([0] * size)
+
+         return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+     def add_prefix_instruction(self, prompt):
+         user_prompt = "<|user|>\n"
+         generation_prompt = "Generate an image according to the following instructions\n"
+         assistant_prompt = "<|assistant|>\n<|diffusion|>"
+         prompt_suffix = "<|end|>\n"
+         prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+         return prompt
+
+     def __call__(
+         self,
+         instructions: List[str],
+         input_images: List[List[str]] = None,
+         height: int = 1024,
+         width: int = 1024,
+         negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+         use_img_cfg: bool = True,
+         separate_cfg_input: bool = False,
+         use_input_image_size_as_output: bool = False,
+         num_images_per_prompt: int = 1,
+     ) -> Dict:
+         if isinstance(instructions, str):
+             instructions = [instructions]
+             input_images = [input_images]
+
+         input_data = []
+         for i in range(len(instructions)):
+             cur_instruction = instructions[i]
+             cur_input_images = None if input_images is None else input_images[i]
+             if cur_input_images is not None and len(cur_input_images) > 0:
+                 cur_input_images = [self.process_image(x) for x in cur_input_images]
+             else:
+                 cur_input_images = None
+                 assert "<img><|image_1|></img>" not in cur_instruction
+
+             mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+             neg_mllm_input, img_cfg_mllm_input = None, None
+             neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+             if use_img_cfg:
+                 if cur_input_images is not None and len(cur_input_images) >= 1:
+                     img_cfg_prompt = [f"<img><|image_{i + 1}|></img>" for i in range(len(cur_input_images))]
+                     img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+                 else:
+                     img_cfg_mllm_input = neg_mllm_input
+
+             for _ in range(num_images_per_prompt):
+                 if use_input_image_size_as_output:
+                     input_data.append(
+                         (
+                             mllm_input,
+                             neg_mllm_input,
+                             img_cfg_mllm_input,
+                             [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+                         )
+                     )
+                 else:
+                     input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+         return self.collator(input_data)
+
+
+ class OmniGenCollator:
+     def __init__(self, pad_token_id=2, hidden_size=3072):
+         self.pad_token_id = pad_token_id
+         self.hidden_size = hidden_size
+
+     def create_position(self, attention_mask, num_tokens_for_output_images):
+         position_ids = []
+         text_length = attention_mask.size(-1)
+         img_length = max(num_tokens_for_output_images)
+         for mask in attention_mask:
+             temp_l = torch.sum(mask)
+             temp_position = [0] * (text_length - temp_l) + list(
+                 range(temp_l + img_length + 1)
+             )  # we add a time embedding into the sequence, so add one more token
+             position_ids.append(temp_position)
+         return torch.LongTensor(position_ids)
+
+     def create_mask(self, attention_mask, num_tokens_for_output_images):
+         """
+         OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
+         each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+         """
+         extended_mask = []
+         padding_images = []
+         text_length = attention_mask.size(-1)
+         img_length = max(num_tokens_for_output_images)
+         seq_len = text_length + img_length + 1  # we add a time embedding into the sequence, so add one more token
+         inx = 0
+         for mask in attention_mask:
+             temp_l = torch.sum(mask)
+             pad_l = text_length - temp_l
+
+             temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
+
+             image_mask = torch.zeros(size=(temp_l + 1, img_length))
+             temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+             image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
+             temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+             if pad_l > 0:
+                 pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
+                 temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+                 pad_mask = torch.ones(size=(pad_l, seq_len))
+                 temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+             true_img_length = num_tokens_for_output_images[inx]
+             pad_img_length = img_length - true_img_length
+             if pad_img_length > 0:
+                 temp_mask[:, -pad_img_length:] = 0
+                 temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+             else:
+                 temp_padding_imgs = None
+
+             extended_mask.append(temp_mask.unsqueeze(0))
+             padding_images.append(temp_padding_imgs)
+             inx += 1
+         return torch.cat(extended_mask, dim=0), padding_images
+
+     def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+         for b_inx in image_sizes.keys():
+             for start_inx, end_inx in image_sizes[b_inx]:
+                 attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+         return attention_mask
+
+     def pad_input_ids(self, input_ids, image_sizes):
+         max_l = max([len(x) for x in input_ids])
+         padded_ids = []
+         attention_mask = []
+
+         for i in range(len(input_ids)):
+             temp_ids = input_ids[i]
+             temp_l = len(temp_ids)
+             pad_l = max_l - temp_l
+             if pad_l == 0:
+                 attention_mask.append([1] * max_l)
+                 padded_ids.append(temp_ids)
+             else:
+                 attention_mask.append([0] * pad_l + [1] * temp_l)
+                 padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
+
+             if i in image_sizes:
+                 new_inx = []
+                 for old_inx in image_sizes[i]:
+                     new_inx.append([x + pad_l for x in old_inx])
+                 image_sizes[i] = new_inx
+
+         return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+     def process_mllm_input(self, mllm_inputs, target_img_size):
+         num_tokens_for_output_images = []
+         for img_size in target_img_size:
+             num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
+
+         pixel_values, image_sizes = [], {}
+         b_inx = 0
+         for x in mllm_inputs:
+             if x["pixel_values"] is not None:
+                 pixel_values.extend(x["pixel_values"])
+                 for size in x["image_sizes"]:
+                     if b_inx not in image_sizes:
+                         image_sizes[b_inx] = [size]
+                     else:
+                         image_sizes[b_inx].append(size)
+             b_inx += 1
+         pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+         input_ids = [x["input_ids"] for x in mllm_inputs]
+         padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+         position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+         attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+         attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+         return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+     def __call__(self, features):
+         mllm_inputs = [f[0] for f in features]
+         cfg_mllm_inputs = [f[1] for f in features]
+         img_cfg_mllm_input = [f[2] for f in features]
+         target_img_size = [f[3] for f in features]
+
+         if img_cfg_mllm_input[0] is not None:
+             mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+             target_img_size = target_img_size + target_img_size + target_img_size
+         else:
+             mllm_inputs = mllm_inputs + cfg_mllm_inputs
+             target_img_size = target_img_size + target_img_size
+
+         (
+             all_padded_input_ids,
+             all_position_ids,
+             all_attention_mask,
+             all_padding_images,
+             all_pixel_values,
+             all_image_sizes,
+         ) = self.process_mllm_input(mllm_inputs, target_img_size)
+
+         data = {
+             "input_ids": all_padded_input_ids,
+             "attention_mask": all_attention_mask,
+             "position_ids": all_position_ids,
+             "input_pixel_values": all_pixel_values,
+             "input_image_sizes": all_image_sizes,
+         }
+         return data
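For orientation, a minimal usage sketch of the processor added above; the tokenizer checkpoint name is an assumption for illustration, not something this diff pins down:

    # Hypothetical usage of OmniGenMultiModalProcessor; the checkpoint name is
    # an assumption, not taken from this diff.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1")  # assumed checkpoint
    processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)

    batch = processor(
        instructions=["Put a red hat on the person in <img><|image_1|></img>"],
        input_images=[["person.png"]],  # one list of image paths per instruction
        height=1024,
        width=1024,
    )
    # `batch` bundles padded input_ids, the block-wise attention mask,
    # position_ids, input_pixel_values, and input_image_sizes, as assembled
    # by OmniGenCollator above.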
diffusers/pipelines/onnx_utils.py
@@ -1,5 +1,5 @@
  # coding=utf-8
- # Copyright 2024 The HuggingFace Inc. team.
+ # Copyright 2025 The HuggingFace Inc. team.
  # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +61,7 @@ class OnnxRuntimeModel:
          return self.model.run(None, inputs)

      @staticmethod
-     def load_model(path: Union[str, Path], provider=None, sess_options=None):
+     def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
          """
          Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

@@ -75,7 +75,9 @@ class OnnxRuntimeModel:
          logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
          provider = "CPUExecutionProvider"

-         return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
+         return ort.InferenceSession(
+             path, providers=[provider], sess_options=sess_options, provider_options=provider_options
+         )

      def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
          """
diffusers/pipelines/pag/pag_utils.py
@@ -158,7 +158,7 @@ class PAGMixin:
          ),
      ):
          r"""
-         Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
+         Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

          Args:
              pag_applied_layers (`str` or `List[str]`):
diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -30,6 +30,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
      USE_PEFT_BACKEND,
+     is_torch_xla_available,
      logging,
      replace_example_docstring,
      scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
  from .pag_utils import PAGMixin


+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -251,7 +259,7 @@ class StableDiffusionControlNetPAGPipeline(
              feature_extractor=feature_extractor,
              image_encoder=image_encoder,
          )
-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
          self.control_image_processor = VaeImageProcessor(
              vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1293,6 +1301,9 @@ class StableDiffusionControlNetPAGPipeline(
                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()

+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
          # If we do sequential model offloading, let's offload unet and controlnet
          # manually for max memory savings
          if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -31,6 +31,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
      USE_PEFT_BACKEND,
+     is_torch_xla_available,
      logging,
      replace_example_docstring,
      scale_lora_layers,
@@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
  from .pag_utils import PAGMixin


+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -228,7 +236,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
              feature_extractor=feature_extractor,
              image_encoder=image_encoder,
          )
-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
          self.mask_processor = VaeImageProcessor(
              vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -596,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
          if padding_mask_crop is not None:
              if not isinstance(image, PIL.Image.Image):
                  raise ValueError(
-                     f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                     f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
                  )
              if not isinstance(mask_image, PIL.Image.Image):
                  raise ValueError(
@@ -604,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                      f" {type(mask_image)}."
                  )
              if output_type != "pil":
-                 raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+                 raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")

          # `prompt` needs more sophisticated handling when there are multiple
          # conditionings.
@@ -1332,7 +1340,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                      f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                      f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                      f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                     f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                     f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                      " `pipeline.unet` or your `mask_image` or `image` input."
                  )
              elif num_channels_unet != 4:
@@ -1505,6 +1513,9 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()

+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
          # If we do sequential model offloading, let's offload unet and controlnet
          # manually for max memory savings
          if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
      from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


+ from ...utils import is_torch_xla_available
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -280,7 +290,7 @@ class StableDiffusionXLControlNetPAGPipeline(
              feature_extractor=feature_extractor,
              image_encoder=image_encoder,
          )
-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
          self.control_image_processor = VaeImageProcessor(
              vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -421,7 +431,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                  # We are only ALWAYS interested in the pooled output of the final text encoder
-                 pooled_prompt_embeds = prompt_embeds[0]
+                 if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+                     pooled_prompt_embeds = prompt_embeds[0]
+
                  if clip_skip is None:
                      prompt_embeds = prompt_embeds.hidden_states[-2]
                  else:
@@ -480,8 +492,10 @@ class StableDiffusionXLControlNetPAGPipeline(
                      uncond_input.input_ids.to(device),
                      output_hidden_states=True,
                  )
+
                  # We are only ALWAYS interested in the pooled output of the final text encoder
-                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                 if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+                     negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1560,6 +1574,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()

+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
          if not output_type == "latent":
              # make sure the VAE is in float32 mode, as it overflows in float16
              needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
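The new `ndim == 2` guards keep user-supplied pooled embeddings from being overwritten: a CLIP text encoder's pooled output is a 2-D `[batch, dim]` tensor, while per-token hidden states are 3-D. A shape-level sketch of the distinction (tensor sizes are illustrative CLIP-L dimensions):

    # Only a 2-D first output is treated as the pooled embedding; a 3-D tensor
    # (token hidden states) is skipped, as is the case where the caller already
    # supplied pooled_prompt_embeds.
    import torch

    pooled = torch.randn(2, 768)      # pooled output: [batch, dim] -> ndim == 2
    hidden = torch.randn(2, 77, 768)  # token states: [batch, seq, dim] -> ndim == 3

    pooled_prompt_embeds = None
    first_output = pooled
    if pooled_prompt_embeds is None and first_output.ndim == 2:
        pooled_prompt_embeds = first_output  # taken; `hidden` would be ignored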
diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
      from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


+ from ...utils import is_torch_xla_available
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -270,7 +280,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
              feature_extractor=feature_extractor,
              image_encoder=image_encoder,
          )
-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
          self.control_image_processor = VaeImageProcessor(
              vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -413,7 +423,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                  # We are only ALWAYS interested in the pooled output of the final text encoder
-                 pooled_prompt_embeds = prompt_embeds[0]
+                 if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+                     pooled_prompt_embeds = prompt_embeds[0]
+
                  if clip_skip is None:
                      prompt_embeds = prompt_embeds.hidden_states[-2]
                  else:
@@ -472,8 +484,10 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                      uncond_input.input_ids.to(device),
                      output_hidden_states=True,
                  )
+
                  # We are only ALWAYS interested in the pooled output of the final text encoder
-                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                 if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+                     negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1626,6 +1640,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()

+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
          # If we do sequential model offloading, let's offload unet and controlnet
          # manually for max memory savings
          if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -245,9 +245,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
                  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
              )

-         self.vae_scale_factor = (
-             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
          self.register_to_config(requires_safety_checker=requires_safety_checker)
          self.default_sample_size = (
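The rewritten one-liner relies on `getattr` with a default standing in for the longer `hasattr(...) and ... is not None` test; for a module-valued attribute (always truthy) the two behave identically. A minimal check with a hypothetical stand-in class:

    # The two guards agree here: a missing attribute and an attribute set to
    # None both fall through to the default scale factor of 8.
    class P:
        pass

    p = P()
    assert not (hasattr(p, "vae") and p.vae is not None)
    assert not getattr(p, "vae", None)

    p.vae = None
    assert not getattr(p, "vae", None)  # still falsy, same as the old check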
diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -202,12 +202,14 @@ class KolorsPAGPipeline(
              feature_extractor=feature_extractor,
          )
          self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-         self.vae_scale_factor = (
-             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-         )
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

-         self.default_sample_size = self.unet.config.sample_size
+         self.default_sample_size = (
+             self.unet.config.sample_size
+             if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+             else 128
+         )

          self.set_pag_applied_layers(pag_applied_layers)

diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -29,6 +29,7 @@ from ...utils import (
      deprecate,
      is_bs4_available,
      is_ftfy_available,
+     is_torch_xla_available,
      logging,
      replace_example_docstring,
  )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
  from .pag_utils import PAGMixin


+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+
  if is_bs4_available():
      from bs4 import BeautifulSoup

@@ -172,7 +181,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
              tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
          )

-         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
          self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

          self.set_pag_applied_layers(pag_applied_layers)
@@ -798,10 +807,11 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                      # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                      # This would be a good case for the `match` statement (Python 3.10+)
                      is_mps = latent_model_input.device.type == "mps"
+                     is_npu = latent_model_input.device.type == "npu"
                      if isinstance(current_timestep, float):
-                         dtype = torch.float32 if is_mps else torch.float64
+                         dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                      else:
-                         dtype = torch.int32 if is_mps else torch.int64
+                         dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                      current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                  elif len(current_timestep.shape) == 0:
                      current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -843,6 +853,9 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                          step_idx = i // getattr(self.scheduler, "order", 1)
                          callback(step_idx, t, latents)

+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
          if not output_type == "latent":
              image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
              if use_resolution_binning:
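The `is_npu` branch extends the existing MPS special case: on both backends the pipeline avoids 64-bit timestep tensors (that Ascend NPUs need the same treatment mirrors the MPS rationale in the surrounding comment; treat the reasoning as this note's assumption). A condensed sketch of the selection logic:

    # Condensed version of the dtype selection: MPS and NPU devices fall back
    # to 32-bit timestep tensors, everything else keeps 64-bit.
    import torch

    def timestep_dtype(device_type: str, timestep_is_float: bool) -> torch.dtype:
        reduced = device_type in ("mps", "npu")
        if timestep_is_float:
            return torch.float32 if reduced else torch.float64
        return torch.int32 if reduced else torch.int64

    assert timestep_dtype("npu", True) is torch.float32
    assert timestep_dtype("cuda", False) is torch.int64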