diffusers-0.32.2-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
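
Highlights among the new modules above: the diffusers/hooks package introduces group offloading, layerwise casting, pyramid attention broadcast, and FasterCache, and diffusers/models/auto_model.py adds an AutoModel class. As a sketch of the group-offloading entry point in 0.33 (method and argument names follow the release documentation; they do not appear in this file list):

    import torch
    from diffusers import CogVideoXTransformer3DModel

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Keep weights on CPU and stream groups of layers onto the GPU on demand,
    # trading some throughput for a much smaller peak VRAM footprint.
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="leaf_level",
        use_stream=True,
    )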
diffusers/models/normalization.py
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_version
+from ..utils import is_torch_npu_available, is_torch_version
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
@@ -71,7 +71,7 @@ class AdaLayerNorm(nn.Module):
 
         if self.chunk_dim == 1:
             # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
-            # other if-branch. This branch is specific to CogVideoX for now.
+            # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
             shift, scale = temb.chunk(2, dim=1)
             shift = shift[:, None, :]
             scale = scale[:, None, :]
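
The comment above points at a real pitfall: this branch splits the conditioning embedding as (shift, scale), while the other branch of AdaLayerNorm splits it as (scale, shift). A minimal illustration of the two conventions, using made-up tensors rather than library code:

    import torch

    x = torch.randn(2, 16, 64)    # (batch, seq, dim), already layer-normalized
    temb = torch.randn(2, 128)    # conditioning embedding of size 2 * dim

    # CogVideoX/OmniGen branch: chunk as (shift, scale), broadcast over seq
    shift, scale = temb.chunk(2, dim=1)
    out_a = x * (1 + scale[:, None, :]) + shift[:, None, :]

    # Other branch: the same tensor chunked as (scale, shift); with identical
    # weights the two orderings give different outputs, so the convention
    # must match the checkpoint the weights were trained with.
    scale, shift = temb.chunk(2, dim=1)
    out_b = x * (1 + scale[:, None, :]) + shift[:, None, :]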
@@ -219,14 +219,13 @@ class LuminaRMSNormZero(nn.Module):
             4 * embedding_dim,
             bias=True,
         )
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps)
 
     def forward(
         self,
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None])
@@ -307,6 +306,20 @@ class AdaGroupNorm(nn.Module):
 
 
 class AdaLayerNormContinuous(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
     def __init__(
         self,
         embedding_dim: int,
@@ -463,6 +476,17 @@ else:
     # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
     class LayerNorm(nn.Module):
+        r"""
+        LayerNorm with the bias parameter.
+
+        Args:
+            dim (`int`): Dimensionality to use for the parameters.
+            eps (`float`, defaults to 1e-5): Epsilon factor.
+            elementwise_affine (`bool`, defaults to `True`):
+                Boolean flag to denote if affine transformation should be applied.
+            bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+        """
+
         def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
             super().__init__()
 
@@ -485,6 +509,17 @@ else:
 
 
 class RMSNorm(nn.Module):
+    r"""
+    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+
+    Args:
+        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
+        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        bias (`bool`, defaults to False): If also training the `bias` param.
+    """
+
     def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
@@ -505,19 +540,30 @@ class RMSNorm(nn.Module):
             self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
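The rewritten forward adds a fast path through the fused torch_npu.npu_rms_norm kernel on Ascend NPUs and keeps the original eager computation as the fallback. For reference, the eager branch reduces to the following standalone function (a direct restatement of the code above, not an addition to the library):

    import torch

    def rms_norm_reference(x: torch.Tensor, weight, eps: float) -> torch.Tensor:
        # Root-mean-square normalization over the last axis, computed in
        # float32 for stability, mirroring RMSNorm.forward's else-branch.
        input_dtype = x.dtype
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        if weight is not None:
            if weight.dtype in (torch.float16, torch.bfloat16):
                x = x.to(weight.dtype)
            x = x * weight
        else:
            x = x.to(input_dtype)
        return x
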
@@ -553,6 +599,13 @@ class MochiRMSNorm(nn.Module):
 
 
 class GlobalResponseNorm(nn.Module):
+    r"""
+    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+
+    Args:
+        dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
+    """
+
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
         super().__init__()
diffusers/models/resnet.py
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.conv2(hidden_states)
 
         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor)
+            input_tensor = self.conv_shortcut(input_tensor.contiguous())
 
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
 
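The sole change to resnet.py makes the shortcut input contiguous before the shortcut convolution. A toy example of the property involved (generic PyTorch behavior, not the specific failure the patch addresses, which this diff does not show):

    import torch

    t = torch.randn(4, 8).t()   # a transpose is a non-contiguous view
    print(t.is_contiguous())    # False
    c = t.contiguous()          # copies into a standard row-major layout
    print(torch.equal(c, t), c.is_contiguous())  # True True
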
diffusers/models/transformers/__init__.py
@@ -4,6 +4,7 @@ from ...utils import is_torch_available
 if is_torch_available():
     from .auraflow_transformer_2d import AuraFlowTransformer2DModel
     from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
+    from .consisid_transformer_3d import ConsisIDTransformer3DModel
     from .dit_transformer_2d import DiTTransformer2DModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .hunyuan_transformer_2d import HunyuanDiT2DModel
@@ -17,9 +18,14 @@ if is_torch_available():
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
+    from .transformer_cogview4 import CogView4Transformer2DModel
+    from .transformer_easyanimate import EasyAnimateTransformer3DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
+    from .transformer_omnigen import OmniGenTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
+    from .transformer_wan import WanTransformer3DModel
diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,14 +13,15 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Union
+from typing import Dict, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -253,7 +254,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
 
@@ -275,6 +276,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
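
The new `_skip_layerwise_casting_patterns` attribute plugs into diffusers/hooks/layerwise_casting.py (added in this release): submodules whose names match the patterns, here positional embeddings and norms, keep their original precision when the rest of the model is stored in a low-precision dtype. A sketch of the user-facing call, with dtypes chosen for illustration (API as documented for 0.33, not shown in this diff):

    import torch
    from diffusers import AuraFlowTransformer2DModel

    transformer = AuraFlowTransformer2DModel.from_pretrained(
        "fal/AuraFlow", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Store weights in FP8, upcast to bfloat16 on the fly during forward;
    # modules matching _skip_layerwise_casting_patterns are left untouched.
    transformer.enable_layerwise_casting(
        storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
    )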
@@ -442,10 +444,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -467,23 +465,11 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
 
             else:
@@ -498,22 +484,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                combined_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                combined_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     combined_hidden_states,
                     temb,
-                    **ckpt_kwargs,
                 )
 
             else:
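
The same refactor repeats across most models in this release: the inline `create_custom_forward` closure and the torch-version check are replaced by a single `self._gradient_checkpointing_func` supplied by ModelMixin. Its definition lives in modeling_utils.py and is not part of this excerpt; a plausible reading, given the code it replaces, is non-reentrant checkpointing with the version guard dropped:

    import functools
    import torch.utils.checkpoint

    # Assumed shape of the new helper, inferred from the removed code; the
    # actual implementation in modeling_utils.py may differ in details.
    _gradient_checkpointing_func = functools.partial(
        torch.utils.checkpoint.checkpoint, use_reentrant=False
    )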
diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -20,10 +20,11 @@ from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -120,8 +121,10 @@ class CogVideoXBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +136,7 @@ class CogVideoXBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
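
Threading `attention_kwargs` through CogVideoXBlock lets per-call options, such as a LoRA `scale` when the PEFT backend is active, reach the attention processors. At the pipeline level this surfaces roughly as follows (the prompt, model ID, and LoRA path are illustrative):

    import torch
    from diffusers import CogVideoXPipeline

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # hypothetical adapter
    video = pipe(
        "a panda playing guitar",
        attention_kwargs={"scale": 0.8},  # forwarded into every block's attention
    ).frames[0]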
@@ -153,7 +157,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
@@ -209,7 +213,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Scaling factor to apply in 3D positional embeddings across temporal dimensions.
     """
 
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -325,9 +331,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -483,21 +486,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
-                    **ckpt_kwargs,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
@@ -505,16 +500,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )
 
-        if not self.config.use_rotary_positional_embeddings:
-            # CogVideoX-2B
-            hidden_states = self.norm_final(hidden_states)
-        else:
-            # CogVideoX-5B
-            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-            hidden_states = self.norm_final(hidden_states)
-            hidden_states = hidden_states[:, text_seq_length:]
+        hidden_states = self.norm_final(hidden_states)
 
         # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
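
Inheriting from the new CacheMixin (backed by diffusers/models/cache_utils.py and the diffusers/hooks modules listed at the top) gives the transformer an `enable_cache` entry point. A sketch of pyramid attention broadcast on this model, with skip ranges chosen for illustration (config fields follow the 0.33 documentation, not this diff):

    import torch
    from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Reuse attention outputs across nearby timesteps instead of recomputing.
    config = PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
        current_timestep_callback=lambda: pipe.current_timestep,
    )
    pipe.transformer.enable_cache(config)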