diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -207,8 +207,8 @@ class HunyuanDiTPipeline(DiffusionPipeline):
207
207
  safety_checker: StableDiffusionSafetyChecker,
208
208
  feature_extractor: CLIPImageProcessor,
209
209
  requires_safety_checker: bool = True,
210
- text_encoder_2=T5EncoderModel,
211
- tokenizer_2=MT5Tokenizer,
210
+ text_encoder_2: Optional[T5EncoderModel] = None,
211
+ tokenizer_2: Optional[MT5Tokenizer] = None,
212
212
  ):
213
213
  super().__init__()
214
214
 
@@ -240,9 +240,7 @@ class HunyuanDiTPipeline(DiffusionPipeline):
240
240
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
241
241
  )
242
242
 
243
- self.vae_scale_factor = (
244
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
245
- )
243
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
246
244
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
247
245
  self.register_to_config(requires_safety_checker=requires_safety_checker)
248
246
  self.default_sample_size = (
@@ -27,6 +27,7 @@ from ...models.unets.unet_i2vgen_xl import I2VGenXLUNet
27
27
  from ...schedulers import DDIMScheduler
28
28
  from ...utils import (
29
29
  BaseOutput,
30
+ is_torch_xla_available,
30
31
  logging,
31
32
  replace_example_docstring,
32
33
  )
@@ -35,8 +36,16 @@ from ...video_processor import VideoProcessor
35
36
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
36
37
 
37
38
 
39
+ if is_torch_xla_available():
40
+ import torch_xla.core.xla_model as xm
41
+
42
+ XLA_AVAILABLE = True
43
+ else:
44
+ XLA_AVAILABLE = False
45
+
38
46
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
47
 
48
+
40
49
  EXAMPLE_DOC_STRING = """
41
50
  Examples:
42
51
  ```py
@@ -133,7 +142,7 @@ class I2VGenXLPipeline(
133
142
  unet=unet,
134
143
  scheduler=scheduler,
135
144
  )
136
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
145
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
137
146
  # `do_resize=False` as we do custom resizing.
138
147
  self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
139
148
 
@@ -711,6 +720,9 @@ class I2VGenXLPipeline(
711
720
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
712
721
  progress_bar.update()
713
722
 
723
+ if XLA_AVAILABLE:
724
+ xm.mark_step()
725
+
714
726
  # 8. Post processing
715
727
  if output_type == "latent":
716
728
  video = latents
@@ -22,6 +22,7 @@ from transformers import (
22
22
  from ...models import UNet2DConditionModel, VQModel
23
23
  from ...schedulers import DDIMScheduler, DDPMScheduler
24
24
  from ...utils import (
25
+ is_torch_xla_available,
25
26
  logging,
26
27
  replace_example_docstring,
27
28
  )
@@ -30,8 +31,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
30
31
  from .text_encoder import MultilingualCLIP
31
32
 
32
33
 
34
+ if is_torch_xla_available():
35
+ import torch_xla.core.xla_model as xm
36
+
37
+ XLA_AVAILABLE = True
38
+ else:
39
+ XLA_AVAILABLE = False
40
+
33
41
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
42
 
43
+
35
44
  EXAMPLE_DOC_STRING = """
36
45
  Examples:
37
46
  ```py
@@ -385,6 +394,9 @@ class KandinskyPipeline(DiffusionPipeline):
385
394
  step_idx = i // getattr(self.scheduler, "order", 1)
386
395
  callback(step_idx, t, latents)
387
396
 
397
+ if XLA_AVAILABLE:
398
+ xm.mark_step()
399
+
388
400
  # post-processing
389
401
  image = self.movq.decode(latents, force_not_quantize=True)["sample"]
390
402
 
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
360
360
  """
361
361
 
362
362
  _load_connected_pipes = True
363
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
363
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
364
364
  _exclude_from_cpu_offload = ["prior_prior"]
365
365
 
366
366
  def __init__(
@@ -25,6 +25,7 @@ from transformers import (
25
25
  from ...models import UNet2DConditionModel, VQModel
26
26
  from ...schedulers import DDIMScheduler
27
27
  from ...utils import (
28
+ is_torch_xla_available,
28
29
  logging,
29
30
  replace_example_docstring,
30
31
  )
@@ -33,8 +34,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
33
34
  from .text_encoder import MultilingualCLIP
34
35
 
35
36
 
37
+ if is_torch_xla_available():
38
+ import torch_xla.core.xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else:
42
+ XLA_AVAILABLE = False
43
+
36
44
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
45
 
46
+
38
47
  EXAMPLE_DOC_STRING = """
39
48
  Examples:
40
49
  ```py
@@ -478,6 +487,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
478
487
  step_idx = i // getattr(self.scheduler, "order", 1)
479
488
  callback(step_idx, t, latents)
480
489
 
490
+ if XLA_AVAILABLE:
491
+ xm.mark_step()
492
+
481
493
  # 7. post-processing
482
494
  image = self.movq.decode(latents, force_not_quantize=True)["sample"]
483
495
 
@@ -29,6 +29,7 @@ from ... import __version__
29
29
  from ...models import UNet2DConditionModel, VQModel
30
30
  from ...schedulers import DDIMScheduler
31
31
  from ...utils import (
32
+ is_torch_xla_available,
32
33
  logging,
33
34
  replace_example_docstring,
34
35
  )
@@ -37,8 +38,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
37
38
  from .text_encoder import MultilingualCLIP
38
39
 
39
40
 
41
+ if is_torch_xla_available():
42
+ import torch_xla.core.xla_model as xm
43
+
44
+ XLA_AVAILABLE = True
45
+ else:
46
+ XLA_AVAILABLE = False
47
+
40
48
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
49
 
50
+
42
51
  EXAMPLE_DOC_STRING = """
43
52
  Examples:
44
53
  ```py
@@ -570,7 +579,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
570
579
  f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
571
580
  f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
572
581
  f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
573
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
582
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
574
583
  " `pipeline.unet` or your `mask_image` or `image` input."
575
584
  )
576
585
 
@@ -613,6 +622,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
613
622
  step_idx = i // getattr(self.scheduler, "order", 1)
614
623
  callback(step_idx, t, latents)
615
624
 
625
+ if XLA_AVAILABLE:
626
+ xm.mark_step()
627
+
616
628
  # post-processing
617
629
  image = self.movq.decode(latents, force_not_quantize=True)["sample"]
618
630
 
@@ -24,6 +24,7 @@ from ...models import PriorTransformer
24
24
  from ...schedulers import UnCLIPScheduler
25
25
  from ...utils import (
26
26
  BaseOutput,
27
+ is_torch_xla_available,
27
28
  logging,
28
29
  replace_example_docstring,
29
30
  )
@@ -31,8 +32,16 @@ from ...utils.torch_utils import randn_tensor
31
32
  from ..pipeline_utils import DiffusionPipeline
32
33
 
33
34
 
35
+ if is_torch_xla_available():
36
+ import torch_xla.core.xla_model as xm
37
+
38
+ XLA_AVAILABLE = True
39
+ else:
40
+ XLA_AVAILABLE = False
41
+
34
42
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
43
 
44
+
36
45
  EXAMPLE_DOC_STRING = """
37
46
  Examples:
38
47
  ```py
@@ -519,6 +528,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
519
528
  prev_timestep=prev_timestep,
520
529
  ).prev_sample
521
530
 
531
+ if XLA_AVAILABLE:
532
+ xm.mark_step()
533
+
522
534
  latents = self.prior.post_process_latents(latents)
523
535
 
524
536
  image_embeddings = latents
@@ -18,13 +18,21 @@ import torch
18
18
 
19
19
  from ...models import UNet2DConditionModel, VQModel
20
20
  from ...schedulers import DDPMScheduler
21
- from ...utils import deprecate, logging, replace_example_docstring
21
+ from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
22
22
  from ...utils.torch_utils import randn_tensor
23
23
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
24
 
25
25
 
26
+ if is_torch_xla_available():
27
+ import torch_xla.core.xla_model as xm
28
+
29
+ XLA_AVAILABLE = True
30
+ else:
31
+ XLA_AVAILABLE = False
32
+
26
33
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
34
 
35
+
28
36
  EXAMPLE_DOC_STRING = """
29
37
  Examples:
30
38
  ```py
@@ -296,6 +304,9 @@ class KandinskyV22Pipeline(DiffusionPipeline):
296
304
  step_idx = i // getattr(self.scheduler, "order", 1)
297
305
  callback(step_idx, t, latents)
298
306
 
307
+ if XLA_AVAILABLE:
308
+ xm.mark_step()
309
+
299
310
  if output_type not in ["pt", "np", "pil", "latent"]:
300
311
  raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
301
312
 
@@ -19,14 +19,23 @@ import torch
19
19
  from ...models import UNet2DConditionModel, VQModel
20
20
  from ...schedulers import DDPMScheduler
21
21
  from ...utils import (
22
+ is_torch_xla_available,
22
23
  logging,
23
24
  )
24
25
  from ...utils.torch_utils import randn_tensor
25
26
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
26
27
 
27
28
 
29
+ if is_torch_xla_available():
30
+ import torch_xla.core.xla_model as xm
31
+
32
+ XLA_AVAILABLE = True
33
+ else:
34
+ XLA_AVAILABLE = False
35
+
28
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
37
 
38
+
30
39
  EXAMPLE_DOC_STRING = """
31
40
  Examples:
32
41
  ```py
@@ -297,6 +306,10 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
297
306
  if callback is not None and i % callback_steps == 0:
298
307
  step_idx = i // getattr(self.scheduler, "order", 1)
299
308
  callback(step_idx, t, latents)
309
+
310
+ if XLA_AVAILABLE:
311
+ xm.mark_step()
312
+
300
313
  # post-processing
301
314
  image = self.movq.decode(latents, force_not_quantize=True)["sample"]
302
315
 
@@ -22,14 +22,23 @@ from PIL import Image
22
22
  from ...models import UNet2DConditionModel, VQModel
23
23
  from ...schedulers import DDPMScheduler
24
24
  from ...utils import (
25
+ is_torch_xla_available,
25
26
  logging,
26
27
  )
27
28
  from ...utils.torch_utils import randn_tensor
28
29
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
29
30
 
30
31
 
32
+ if is_torch_xla_available():
33
+ import torch_xla.core.xla_model as xm
34
+
35
+ XLA_AVAILABLE = True
36
+ else:
37
+ XLA_AVAILABLE = False
38
+
31
39
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
40
 
41
+
33
42
  EXAMPLE_DOC_STRING = """
34
43
  Examples:
35
44
  ```py
@@ -358,6 +367,9 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
358
367
  step_idx = i // getattr(self.scheduler, "order", 1)
359
368
  callback(step_idx, t, latents)
360
369
 
370
+ if XLA_AVAILABLE:
371
+ xm.mark_step()
372
+
361
373
  # post-processing
362
374
  image = self.movq.decode(latents, force_not_quantize=True)["sample"]
363
375
 
@@ -21,13 +21,21 @@ from PIL import Image
21
21
 
22
22
  from ...models import UNet2DConditionModel, VQModel
23
23
  from ...schedulers import DDPMScheduler
24
- from ...utils import deprecate, logging
24
+ from ...utils import deprecate, is_torch_xla_available, logging
25
25
  from ...utils.torch_utils import randn_tensor
26
26
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
27
27
 
28
28
 
29
+ if is_torch_xla_available():
30
+ import torch_xla.core.xla_model as xm
31
+
32
+ XLA_AVAILABLE = True
33
+ else:
34
+ XLA_AVAILABLE = False
35
+
29
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
37
 
38
+
31
39
  EXAMPLE_DOC_STRING = """
32
40
  Examples:
33
41
  ```py
@@ -372,6 +380,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
372
380
  step_idx = i // getattr(self.scheduler, "order", 1)
373
381
  callback(step_idx, t, latents)
374
382
 
383
+ if XLA_AVAILABLE:
384
+ xm.mark_step()
385
+
375
386
  if output_type not in ["pt", "np", "pil", "latent"]:
376
387
  raise ValueError(
377
388
  f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
@@ -25,13 +25,21 @@ from PIL import Image
25
25
  from ... import __version__
26
26
  from ...models import UNet2DConditionModel, VQModel
27
27
  from ...schedulers import DDPMScheduler
28
- from ...utils import deprecate, logging
28
+ from ...utils import deprecate, is_torch_xla_available, logging
29
29
  from ...utils.torch_utils import randn_tensor
30
30
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31
31
 
32
32
 
33
+ if is_torch_xla_available():
34
+ import torch_xla.core.xla_model as xm
35
+
36
+ XLA_AVAILABLE = True
37
+ else:
38
+ XLA_AVAILABLE = False
39
+
33
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
41
 
42
+
35
43
  EXAMPLE_DOC_STRING = """
36
44
  Examples:
37
45
  ```py
@@ -526,6 +534,9 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
526
534
  step_idx = i // getattr(self.scheduler, "order", 1)
527
535
  callback(step_idx, t, latents)
528
536
 
537
+ if XLA_AVAILABLE:
538
+ xm.mark_step()
539
+
529
540
  # post-processing
530
541
  latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
531
542
 
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
7
7
  from ...models import PriorTransformer
8
8
  from ...schedulers import UnCLIPScheduler
9
9
  from ...utils import (
10
+ is_torch_xla_available,
10
11
  logging,
11
12
  replace_example_docstring,
12
13
  )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
15
16
  from ..pipeline_utils import DiffusionPipeline
16
17
 
17
18
 
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
18
26
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
27
 
28
+
20
29
  EXAMPLE_DOC_STRING = """
21
30
  Examples:
22
31
  ```py
@@ -524,6 +533,9 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
524
533
  )
525
534
  text_mask = callback_outputs.pop("text_mask", text_mask)
526
535
 
536
+ if XLA_AVAILABLE:
537
+ xm.mark_step()
538
+
527
539
  latents = self.prior.post_process_latents(latents)
528
540
 
529
541
  image_embeddings = latents
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
7
7
  from ...models import PriorTransformer
8
8
  from ...schedulers import UnCLIPScheduler
9
9
  from ...utils import (
10
+ is_torch_xla_available,
10
11
  logging,
11
12
  replace_example_docstring,
12
13
  )
@@ -15,8 +16,16 @@ from ..kandinsky import KandinskyPriorPipelineOutput
15
16
  from ..pipeline_utils import DiffusionPipeline
16
17
 
17
18
 
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
18
26
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
27
 
28
+
20
29
  EXAMPLE_DOC_STRING = """
21
30
  Examples:
22
31
  ```py
@@ -538,6 +547,9 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
538
547
  prev_timestep=prev_timestep,
539
548
  ).prev_sample
540
549
 
550
+ if XLA_AVAILABLE:
551
+ xm.mark_step()
552
+
541
553
  latents = self.prior.post_process_latents(latents)
542
554
 
543
555
  image_embeddings = latents
@@ -8,6 +8,7 @@ from ...models import Kandinsky3UNet, VQModel
8
8
  from ...schedulers import DDPMScheduler
9
9
  from ...utils import (
10
10
  deprecate,
11
+ is_torch_xla_available,
11
12
  logging,
12
13
  replace_example_docstring,
13
14
  )
@@ -15,8 +16,16 @@ from ...utils.torch_utils import randn_tensor
15
16
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
16
17
 
17
18
 
19
+ if is_torch_xla_available():
20
+ import torch_xla.core.xla_model as xm
21
+
22
+ XLA_AVAILABLE = True
23
+ else:
24
+ XLA_AVAILABLE = False
25
+
18
26
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
27
 
28
+
20
29
  EXAMPLE_DOC_STRING = """
21
30
  Examples:
22
31
  ```py
@@ -549,6 +558,9 @@ class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
549
558
  step_idx = i // getattr(self.scheduler, "order", 1)
550
559
  callback(step_idx, t, latents)
551
560
 
561
+ if XLA_AVAILABLE:
562
+ xm.mark_step()
563
+
552
564
  # post-processing
553
565
  if output_type not in ["pt", "np", "pil", "latent"]:
554
566
  raise ValueError(
@@ -12,6 +12,7 @@ from ...models import Kandinsky3UNet, VQModel
12
12
  from ...schedulers import DDPMScheduler
13
13
  from ...utils import (
14
14
  deprecate,
15
+ is_torch_xla_available,
15
16
  logging,
16
17
  replace_example_docstring,
17
18
  )
@@ -19,8 +20,16 @@ from ...utils.torch_utils import randn_tensor
19
20
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
20
21
 
21
22
 
23
+ if is_torch_xla_available():
24
+ import torch_xla.core.xla_model as xm
25
+
26
+ XLA_AVAILABLE = True
27
+ else:
28
+ XLA_AVAILABLE = False
29
+
22
30
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
31
 
32
+
24
33
  EXAMPLE_DOC_STRING = """
25
34
  Examples:
26
35
  ```py
@@ -617,6 +626,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
617
626
  step_idx = i // getattr(self.scheduler, "order", 1)
618
627
  callback(step_idx, t, latents)
619
628
 
629
+ if XLA_AVAILABLE:
630
+ xm.mark_step()
631
+
620
632
  # post-processing
621
633
  if output_type not in ["pt", "np", "pil", "latent"]:
622
634
  raise ValueError(
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
19
19
 
20
20
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
21
21
  from ...image_processor import PipelineImageInput, VaeImageProcessor
22
- from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
22
+ from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
23
23
  from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
24
24
  from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
25
25
  from ...schedulers import KarrasDiffusionSchedulers
@@ -121,7 +121,7 @@ def retrieve_timesteps(
121
121
  return timesteps, num_inference_steps
122
122
 
123
123
 
124
- class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
124
+ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin):
125
125
  r"""
126
126
  Pipeline for text-to-image generation using Kolors.
127
127
 
@@ -129,8 +129,8 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
129
129
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
130
130
 
131
131
  The pipeline also inherits the following loading methods:
132
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
133
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
132
+ - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
133
+ - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
134
134
  - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
135
135
 
136
136
  Args:
@@ -188,12 +188,14 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
188
188
  feature_extractor=feature_extractor,
189
189
  )
190
190
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
191
- self.vae_scale_factor = (
192
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
193
- )
191
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
194
192
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
195
193
 
196
- self.default_sample_size = self.unet.config.sample_size
194
+ self.default_sample_size = (
195
+ self.unet.config.sample_size
196
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
197
+ else 128
198
+ )
197
199
 
198
200
  def encode_prompt(
199
201
  self,
@@ -207,12 +207,14 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
207
207
  feature_extractor=feature_extractor,
208
208
  )
209
209
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
210
- self.vae_scale_factor = (
211
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
212
- )
210
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
213
211
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
214
212
 
215
- self.default_sample_size = self.unet.config.sample_size
213
+ self.default_sample_size = (
214
+ self.unet.config.sample_size
215
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
216
+ else 128
217
+ )
216
218
 
217
219
  # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt
218
220
  def encode_prompt(