diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/hooks/pyramid_attention_broadcast.py (new file)
@@ -0,0 +1,311 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ from dataclasses import dataclass
+ from typing import Any, Callable, Optional, Tuple, Union
+
+ import torch
+
+ from ..models.attention_processor import Attention, MochiAttention
+ from ..utils import logging
+ from .hooks import HookRegistry, ModelHook
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ _PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
+ _ATTENTION_CLASSES = (Attention, MochiAttention)
+ _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
+ _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+ _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
+
+
+ @dataclass
+ class PyramidAttentionBroadcastConfig:
+     r"""
+     Configuration for Pyramid Attention Broadcast.
+
+     Args:
+         spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+             The number of times a specific spatial attention broadcast is skipped before computing the attention states
+             to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+             old attention states will be re-used) before computing the new attention states again.
+         temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+             The number of times a specific temporal attention broadcast is skipped before computing the attention
+             states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
+             (i.e., old attention states will be re-used) before computing the new attention states again.
+         cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
+             The number of times a specific cross-attention broadcast is skipped before computing the attention states
+             to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
+             old attention states will be re-used) before computing the new attention states again.
+         spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+             The range of timesteps to skip in the spatial attention layer. The attention computations will be
+             conditionally skipped if the current timestep is within the specified range.
+         temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+             The range of timesteps to skip in the temporal attention layer. The attention computations will be
+             conditionally skipped if the current timestep is within the specified range.
+         cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
+             The range of timesteps to skip in the cross-attention layer. The attention computations will be
+             conditionally skipped if the current timestep is within the specified range.
+         spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+             The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
+         temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+             The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
+         cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+             The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
+     """
+
+     spatial_attention_block_skip_range: Optional[int] = None
+     temporal_attention_block_skip_range: Optional[int] = None
+     cross_attention_block_skip_range: Optional[int] = None
+
+     spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+     temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+     cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
+
+     spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
+     temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+     cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+
+     current_timestep_callback: Callable[[], int] = None
+
+     # TODO(aryan): add PAB for MLP layers (very limited speedup from testing with original codebase
+     # so not added for now)
+
+     def __repr__(self) -> str:
+         return (
+             f"PyramidAttentionBroadcastConfig(\n"
+             f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
+             f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
+             f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
+             f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
+             f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
+             f" cross_attention_timestep_skip_range={self.cross_attention_timestep_skip_range},\n"
+             f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
+             f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
+             f" cross_attention_block_identifiers={self.cross_attention_block_identifiers},\n"
+             f" current_timestep_callback={self.current_timestep_callback}\n"
+             ")"
+         )
+
+
+ class PyramidAttentionBroadcastState:
+     r"""
+     State for Pyramid Attention Broadcast.
+
+     Attributes:
+         iteration (`int`):
+             The current iteration of the Pyramid Attention Broadcast. It is necessary to ensure that `reset_state` is
+             called before starting a new inference forward pass for PAB to work correctly.
+         cache (`Any`):
+             The cached output from the previous forward pass. This is used to re-use the attention states when the
+             attention computation is skipped. It is either a tensor or a tuple of tensors, depending on the module.
+     """
+
+     def __init__(self) -> None:
+         self.iteration = 0
+         self.cache = None
+
+     def reset(self):
+         self.iteration = 0
+         self.cache = None
+
+     def __repr__(self):
+         cache_repr = ""
+         if self.cache is None:
+             cache_repr = "None"
+         else:
+             cache_repr = f"Tensor(shape={self.cache.shape}, dtype={self.cache.dtype})"
+         return f"PyramidAttentionBroadcastState(iteration={self.iteration}, cache={cache_repr})"
+
+
+ class PyramidAttentionBroadcastHook(ModelHook):
+     r"""A hook that applies Pyramid Attention Broadcast to a given module."""
+
+     _is_stateful = True
+
+     def __init__(
+         self, timestep_skip_range: Tuple[int, int], block_skip_range: int, current_timestep_callback: Callable[[], int]
+     ) -> None:
+         super().__init__()
+
+         self.timestep_skip_range = timestep_skip_range
+         self.block_skip_range = block_skip_range
+         self.current_timestep_callback = current_timestep_callback
+
+     def initialize_hook(self, module):
+         self.state = PyramidAttentionBroadcastState()
+         return module
+
+     def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
+         is_within_timestep_range = (
+             self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
+         )
+         should_compute_attention = (
+             self.state.cache is None
+             or self.state.iteration == 0
+             or not is_within_timestep_range
+             or self.state.iteration % self.block_skip_range == 0
+         )
+
+         if should_compute_attention:
+             output = self.fn_ref.original_forward(*args, **kwargs)
+         else:
+             output = self.state.cache
+
+         self.state.cache = output
+         self.state.iteration += 1
+         return output
+
+     def reset_state(self, module: torch.nn.Module) -> None:
+         self.state.reset()
+         return module
+
+
+ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
+     r"""
+     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
+
+     PAB is an attention approximation method that leverages the similarity in attention states between timesteps to
+     reduce the computational cost of attention computation. The key takeaway from the paper is that the attention
+     similarity in the cross-attention layers between timesteps is high, followed by less similarity in the temporal and
+     spatial layers. This allows for the skipping of attention computation in the cross-attention layers more frequently
+     than in the temporal and spatial layers. Applying PAB will, therefore, speedup the inference process.
+
+     Args:
+         module (`torch.nn.Module`):
+             The module to apply Pyramid Attention Broadcast to.
+         config (`Optional[PyramidAttentionBroadcastConfig]`, `optional`, defaults to `None`):
+             The configuration to use for Pyramid Attention Broadcast.
+
+     Example:
+
+     ```python
+     >>> import torch
+     >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+     >>> from diffusers.utils import export_to_video
+
+     >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+     >>> pipe.to("cuda")
+
+     >>> config = PyramidAttentionBroadcastConfig(
+     ...     spatial_attention_block_skip_range=2,
+     ...     spatial_attention_timestep_skip_range=(100, 800),
+     ...     current_timestep_callback=lambda: pipe.current_timestep,
+     ... )
+     >>> apply_pyramid_attention_broadcast(pipe.transformer, config)
+     ```
+     """
+     if config.current_timestep_callback is None:
+         raise ValueError(
+             "The `current_timestep_callback` function must be provided in the configuration to apply Pyramid Attention Broadcast."
+         )
+
+     if (
+         config.spatial_attention_block_skip_range is None
+         and config.temporal_attention_block_skip_range is None
+         and config.cross_attention_block_skip_range is None
+     ):
+         logger.warning(
+             "Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
+             "or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
+             "To avoid this warning, please set one of the above parameters."
+         )
+         config.spatial_attention_block_skip_range = 2
+
+     for name, submodule in module.named_modules():
+         if not isinstance(submodule, _ATTENTION_CLASSES):
+             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
+             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
+             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
+             continue
+         _apply_pyramid_attention_broadcast_on_attention_class(name, submodule, config)
+
+
+ def _apply_pyramid_attention_broadcast_on_attention_class(
+     name: str, module: Attention, config: PyramidAttentionBroadcastConfig
+ ) -> bool:
+     is_spatial_self_attention = (
+         any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
+         and config.spatial_attention_block_skip_range is not None
+         and not getattr(module, "is_cross_attention", False)
+     )
+     is_temporal_self_attention = (
+         any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
+         and config.temporal_attention_block_skip_range is not None
+         and not getattr(module, "is_cross_attention", False)
+     )
+     is_cross_attention = (
+         any(re.search(identifier, name) is not None for identifier in config.cross_attention_block_identifiers)
+         and config.cross_attention_block_skip_range is not None
+         and getattr(module, "is_cross_attention", False)
+     )
+
+     block_skip_range, timestep_skip_range, block_type = None, None, None
+     if is_spatial_self_attention:
+         block_skip_range = config.spatial_attention_block_skip_range
+         timestep_skip_range = config.spatial_attention_timestep_skip_range
+         block_type = "spatial"
+     elif is_temporal_self_attention:
+         block_skip_range = config.temporal_attention_block_skip_range
+         timestep_skip_range = config.temporal_attention_timestep_skip_range
+         block_type = "temporal"
+     elif is_cross_attention:
+         block_skip_range = config.cross_attention_block_skip_range
+         timestep_skip_range = config.cross_attention_timestep_skip_range
+         block_type = "cross"
+
+     if block_skip_range is None or timestep_skip_range is None:
+         logger.info(
+             f'Unable to apply Pyramid Attention Broadcast to the selected layer: "{name}" because it does '
+             f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, "
+             f"however, that this layer may still be valid for applying PAB. Please specify the correct "
+             f"block identifiers in the configuration."
+         )
+         return False
+
+     logger.debug(f"Enabling Pyramid Attention Broadcast ({block_type}) in layer: {name}")
+     _apply_pyramid_attention_broadcast_hook(
+         module, timestep_skip_range, block_skip_range, config.current_timestep_callback
+     )
+     return True
+
+
+ def _apply_pyramid_attention_broadcast_hook(
+     module: Union[Attention, MochiAttention],
+     timestep_skip_range: Tuple[int, int],
+     block_skip_range: int,
+     current_timestep_callback: Callable[[], int],
+ ):
+     r"""
+     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given torch.nn.Module.
+
+     Args:
+         module (`torch.nn.Module`):
+             The module to apply Pyramid Attention Broadcast to.
+         timestep_skip_range (`Tuple[int, int]`):
+             The range of timesteps to skip in the attention layer. The attention computations will be conditionally
+             skipped if the current timestep is within the specified range.
+         block_skip_range (`int`):
+             The number of times a specific attention broadcast is skipped before computing the attention states to
+             re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
+             attention states will be re-used) before computing the new attention states again.
+         current_timestep_callback (`Callable[[], int]`):
+             A callback function that returns the current inference timestep.
+     """
+     registry = HookRegistry.check_if_exists_or_initialize(module)
+     hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
+     registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
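The caching cadence implemented by `new_forward` above is easier to follow outside the diff. Below is a small illustrative sketch (not part of the package) of how a `block_skip_range` of 2 alternates between recomputing attention and re-using the cached output while the current timestep is inside the configured skip range:

```python
# Illustrative sketch of the PyramidAttentionBroadcastHook skip cadence for block_skip_range=2.
# Attention is recomputed on iterations 0, 2, 4, ... and the cached output is re-used on 1, 3, 5, ...
block_skip_range = 2
cache = None
for iteration in range(6):
    # Mirrors the `should_compute_attention` condition in the hook (timestep-range check omitted here).
    should_compute_attention = cache is None or iteration == 0 or iteration % block_skip_range == 0
    cache = f"attention output from iteration {iteration}" if should_compute_attention else cache
    print(iteration, "compute" if should_compute_attention else "re-use cache")
```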
diffusers/loaders/__init__.py
@@ -70,9 +70,12 @@ if is_torch_available():
          "LoraLoaderMixin",
          "FluxLoraLoaderMixin",
          "CogVideoXLoraLoaderMixin",
+         "CogView4LoraLoaderMixin",
          "Mochi1LoraLoaderMixin",
          "HunyuanVideoLoraLoaderMixin",
          "SanaLoraLoaderMixin",
+         "Lumina2LoraLoaderMixin",
+         "WanLoraLoaderMixin",
      ]
      _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
      _import_structure["ip_adapter"] = [
@@ -101,15 +104,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
      from .lora_pipeline import (
          AmusedLoraLoaderMixin,
          CogVideoXLoraLoaderMixin,
+         CogView4LoraLoaderMixin,
          FluxLoraLoaderMixin,
          HunyuanVideoLoraLoaderMixin,
          LoraLoaderMixin,
          LTXVideoLoraLoaderMixin,
+         Lumina2LoraLoaderMixin,
          Mochi1LoraLoaderMixin,
          SanaLoraLoaderMixin,
          SD3LoraLoaderMixin,
          StableDiffusionLoraLoaderMixin,
          StableDiffusionXLLoraLoaderMixin,
+         WanLoraLoaderMixin,
      )
      from .single_file import FromSingleFileMixin
      from .textual_inversion import TextualInversionLoaderMixin
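These new mixins are what expose `load_lora_weights` on the corresponding pipelines added in 0.33.0. A minimal usage sketch, assuming a hypothetical Wan checkpoint and LoRA repository (the repo ids below are placeholders, not real checkpoints):

```python
# Hedged sketch only: placeholder repo ids, not real checkpoints.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("<wan-model-repo>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<wan-lora-repo-or-path>")  # provided by the new WanLoraLoaderMixin
```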
diffusers/loaders/ip_adapter.py
@@ -23,7 +23,9 @@ from safetensors import safe_open
  from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
  from ..utils import (
      USE_PEFT_BACKEND,
+     _get_detailed_type,
      _get_model_file,
+     _is_valid_type,
      is_accelerate_available,
      is_torch_version,
      is_transformers_available,
@@ -213,7 +215,8 @@ class IPAdapterMixin:
                      low_cpu_mem_usage=low_cpu_mem_usage,
                      cache_dir=cache_dir,
                      local_files_only=local_files_only,
-                 ).to(self.device, dtype=self.dtype)
+                     torch_dtype=self.dtype,
+                 ).to(self.device)
                  self.register_modules(image_encoder=image_encoder)
              else:
                  raise ValueError(
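The change above (and the similar ones in the Flux and SD3 mixins further down) moves dtype selection into the `from_pretrained` call instead of casting the image encoder after loading. A minimal sketch of the resulting pattern, with a placeholder repository id:

```python
# Sketch of the load-time dtype pattern adopted above (placeholder repo id).
import torch
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "<ip-adapter-repo>", subfolder="image_encoder", torch_dtype=torch.float16
).to("cuda")  # only device placement remains a post-load step
```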
@@ -292,8 +295,7 @@
              ):
                  if len(scale_configs) != len(attn_processor.scale):
                      raise ValueError(
-                         f"Cannot assign {len(scale_configs)} scale_configs to "
-                         f"{len(attn_processor.scale)} IP-Adapter."
+                         f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
                      )
                  elif len(scale_configs) == 1:
                      scale_configs = scale_configs * len(attn_processor.scale)
@@ -524,8 +526,9 @@ class FluxIPAdapterMixin:
                          low_cpu_mem_usage=low_cpu_mem_usage,
                          cache_dir=cache_dir,
                          local_files_only=local_files_only,
+                         dtype=image_encoder_dtype,
                      )
-                     .to(self.device, dtype=image_encoder_dtype)
+                     .to(self.device)
                      .eval()
                  )
                  self.register_modules(image_encoder=image_encoder)
@@ -577,29 +580,36 @@ class FluxIPAdapterMixin:
          pipeline.set_ip_adapter_scale(ip_strengths)
          ```
          """
-         transformer = self.transformer
-         if not isinstance(scale, list):
-             scale = [[scale] * transformer.config.num_layers]
-         elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
-             if len(scale) != transformer.config.num_layers:
-                 raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
+
+         scale_type = Union[int, float]
+         num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
+         num_layers = self.transformer.config.num_layers
+
+         # Single value for all layers of all IP-Adapters
+         if isinstance(scale, scale_type):
+             scale = [scale for _ in range(num_ip_adapters)]
+         # List of per-layer scales for a single IP-Adapter
+         elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
              scale = [scale]
+         # Invalid scale type
+         elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
+             raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")

-         scale_configs = scale
+         if len(scale) != num_ip_adapters:
+             raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")

-         key_id = 0
-         for attn_name, attn_processor in transformer.attn_processors.items():
-             if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
-                 if len(scale_configs) != len(attn_processor.scale):
-                     raise ValueError(
-                         f"Cannot assign {len(scale_configs)} scale_configs to "
-                         f"{len(attn_processor.scale)} IP-Adapter."
-                     )
-                 elif len(scale_configs) == 1:
-                     scale_configs = scale_configs * len(attn_processor.scale)
-                 for i, scale_config in enumerate(scale_configs):
-                     attn_processor.scale[i] = scale_config[key_id]
-                 key_id += 1
+         if any(len(s) != num_layers for s in scale if isinstance(s, list)):
+             invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
+             raise ValueError(
+                 f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
+             )
+
+         # Scalars are transformed to lists with length num_layers
+         scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
+
+         # Set scales. zip over scale_configs prevents going into single transformer layers
+         for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
+             attn_processor.scale = scale

      def unload_ip_adapter(self):
          """
@@ -793,12 +803,10 @@ class SD3IPAdapterMixin:
              }

              self.register_modules(
-                 feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
-                     self.device, dtype=self.dtype
-                 ),
-                 image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
-                     self.device, dtype=self.dtype
-                 ),
+                 feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
+                 image_encoder=SiglipVisionModel.from_pretrained(
+                     image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
+                 ).to(self.device),
              )
          else:
              raise ValueError(