diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_2d.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ...configuration_utils import LegacyConfigMixin, register_to_config
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import deprecate, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
 
     @register_to_config
     def __init__(
@@ -210,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
     def _init_vectorized_inputs(self, norm_type):
         assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-        assert (
-            self.config.num_vector_embeds is not None
-        ), "Transformer2DModel over discrete input must provide num_embed"
+        assert self.config.num_vector_embeds is not None, (
+            "Transformer2DModel over discrete input must provide num_embed"
+        )
 
         self.height = self.config.sample_size
         self.width = self.config.sample_size
@@ -320,10 +321,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
             in_features=self.caption_channels, hidden_size=self.inner_dim
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -416,19 +413,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -436,7 +422,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
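
Note on the gradient-checkpointing change above: the per-block `create_custom_forward` closures and the `is_torch_version`-gated `ckpt_kwargs` are replaced by a single shared `_gradient_checkpointing_func` on the model. A minimal sketch of the resulting pattern (a standalone toy module, not the diffusers internals; the actual ModelMixin wiring in this release may differ):

    import functools

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
            self.gradient_checkpointing = False
            # Non-reentrant checkpointing, the same behavior the removed
            # is_torch_version(">=", "1.11.0") branch always selected on torch >= 1.11.
            self._gradient_checkpointing_func = functools.partial(
                torch.utils.checkpoint.checkpoint, use_reentrant=False
            )

        def forward(self, hidden_states):
            for block in self.blocks:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    # One shared callable instead of a closure defined on every iteration.
                    hidden_states = self._gradient_checkpointing_func(block, hidden_states)
                else:
                    hidden_states = block(hidden_states)
            return hidden_states

    model = TinyModel()
    model.gradient_checkpointing = True
    out = model(torch.randn(2, 8, requires_grad=True))
    out.sum().backward()

Centralizing the callable also lets a library swap in a different checkpoint implementation in one place instead of editing every model's forward loop.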
diffusers/models/transformers/transformer_allegro.py
@@ -13,17 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
         return hidden_states
 
 
-class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
+class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
 
     """
@@ -221,6 +222,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
             Scaling factor to apply in 3D positional embeddings across time dimension.
     """
 
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
+
    @register_to_config
    def __init__(
        self,
@@ -300,9 +304,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -372,23 +373,14 @@
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     timestep,
                     attention_mask,
                     encoder_attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
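
The `CacheMixin` base added to `AllegroTransformer3DModel` above exposes the new cache hooks from `diffusers/models/cache_utils.py` (entry 45 in the file list). A hedged sketch of how a cache config might be attached, assuming the `enable_cache` / `PyramidAttentionBroadcastConfig` API introduced in this release, the public `rhymes-ai/Allegro` checkpoint, and that the pipeline tracks `current_timestep` the way the CogVideoX example in this release does:

    import torch
    from diffusers import AllegroPipeline, PyramidAttentionBroadcastConfig

    pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Skip and reuse attention states on a pyramid schedule across denoising steps.
    config = PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
        current_timestep_callback=lambda: pipe.current_timestep,
    )
    pipe.transformer.enable_cache(config)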
diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Union
+from typing import Dict, Union
 
 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -166,6 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -287,10 +289,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -342,20 +340,11 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
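
`_skip_layerwise_casting_patterns`, added here and in the other models above, marks modules (patch embeddings and norms) that keep full precision when the new layerwise-casting hooks (`diffusers/hooks/layerwise_casting.py`, entry 8 in the file list) downcast the remaining weights for storage. A sketch of the intended usage, assuming the `enable_layerwise_casting` model API added in this release and the public `THUDM/CogView3-Plus-3B` checkpoint:

    import torch
    from diffusers import CogView3PlusTransformer2DModel

    transformer = CogView3PlusTransformer2DModel.from_pretrained(
        "THUDM/CogView3-Plus-3B", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Weights are stored in fp8 and upcast to bf16 per layer at compute time;
    # modules matching _skip_layerwise_casting_patterns are left untouched.
    transformer.enable_layerwise_casting(
        storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
    )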
diffusers/models/transformers/transformer_cogview4.py (new file)
@@ -0,0 +1,462 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CogView4PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+        post_patch_height = height // self.patch_size
+        post_patch_width = width // self.patch_size
+
+        hidden_states = hidden_states.reshape(
+            batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
+        )
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4AdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, dim: int) -> None:
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        norm_hidden_states = self.norm(hidden_states)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+        emb = self.linear(temb)
+        (
+            shift_msa,
+            c_shift_msa,
+            scale_msa,
+            c_scale_msa,
+            gate_msa,
+            c_gate_msa,
+            shift_mlp,
+            c_shift_mlp,
+            scale_mlp,
+            c_scale_mlp,
+            gate_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+
+        hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        )
+
+
+class CogView4AttnProcessor:
+    """
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # 1. QKV projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+            )
+            key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+            )
+
+        # 4. Attention
+        if attention_mask is not None:
+            text_attention_mask = attention_mask.float().to(query.device)
+            actual_text_seq_length = text_attention_mask.size(1)
+            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
+            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
+            new_attention_mask = new_attention_mask.unsqueeze(2)
+            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
+            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4TransformerBlock(nn.Module):
+    def __init__(
+        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+    ) -> None:
+        super().__init__()
+
+        # 1. Attention
+        self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            out_dim=dim,
+            bias=True,
+            qk_norm="layer_norm",
+            elementwise_affine=False,
+            eps=1e-5,
+            processor=CogView4AttnProcessor(),
+        )
+
+        # 2. Feedforward
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        # 1. Timestep conditioning
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+        # 2. Attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+        hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+        # 3. Feedforward
+        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
+        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4RotaryPosEmbed(nn.Module):
+    def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
+        super().__init__()
+
+        self.dim = dim
+        self.patch_size = patch_size
+        self.rope_axes_dim = rope_axes_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, num_channels, height, width = hidden_states.shape
+        height, width = height // self.patch_size, width // self.patch_size
+
+        dim_h, dim_w = self.dim // 2, self.dim // 2
+        h_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(self.rope_axes_dim[0])
+        w_seq = torch.arange(self.rope_axes_dim[1])
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+
+        h_idx = torch.arange(height, device=freqs_h.device)
+        w_idx = torch.arange(width, device=freqs_w.device)
+        inner_h_idx = h_idx * self.rope_axes_dim[0] // height
+        inner_w_idx = w_idx * self.rope_axes_dim[1] // width
+
+        freqs_h = freqs_h[inner_h_idx]
+        freqs_w = freqs_w[inner_w_idx]
+
+        # Create position matrices for height and width
+        # [height, 1, dim//4] and [1, width, dim//4]
+        freqs_h = freqs_h.unsqueeze(1)
+        freqs_w = freqs_w.unsqueeze(0)
+        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+        freqs_h = freqs_h.expand(height, width, -1)
+        freqs_w = freqs_w.expand(height, width, -1)
+
+        # Concatenate along last dimension to get [height, width, dim//2]
+        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
+        freqs = freqs.reshape(height * width, -1)
+        return (freqs.cos(), freqs.sin())
+
+
+class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+    r"""
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        num_layers (`int`, defaults to `30`):
+            The number of layers of Transformer blocks to use.
+        attention_head_dim (`int`, defaults to `40`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `64`):
+            The number of heads to use for multi-head attention.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        condition_dim (`int`, defaults to `256`):
+            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+            crop_coords).
+        pos_embed_max_size (`int`, defaults to `128`):
+            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+            patch_size => 128 * 8 * 2 => 2048`.
+        sample_size (`int`, defaults to `128`):
+            The base resolution of input latents. If height/width is not provided during generation, this value is used
+            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_layers: int = 30,
+        attention_head_dim: int = 40,
+        num_attention_heads: int = 64,
+        text_embed_dim: int = 4096,
+        time_embed_dim: int = 512,
+        condition_dim: int = 256,
+        pos_embed_max_size: int = 128,
+        sample_size: int = 128,
+        rope_axes_dim: Tuple[int, int] = (256, 256),
+    ):
+        super().__init__()
+
+        # CogView4 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+        # Each of these are sincos embeddings of shape 2 * condition_dim
+        pooled_projection_dim = 3 * 2 * condition_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels
+
+        # 1. RoPE
+        self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0)
+
+        # 2. Patch & Text-timestep embedding
+        self.patch_embed = CogView4PatchEmbed(in_channels, inner_dim, patch_size, text_embed_dim)
+
+        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+            embedding_dim=time_embed_dim,
+            condition_dim=condition_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            timesteps_dim=inner_dim,
+        )
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output projection
+        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, height, width = hidden_states.shape
+
+        # 1. RoPE
+        image_rotary_emb = self.rope(hidden_states)
+
+        # 2. Patch & Timestep embeddings
+        p = self.config.patch_size
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
+        temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+        temb = F.silu(temb)
+
+        # 3. Transformer blocks
+        for block in self.transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                )
+
+        # 4. Output norm & projection
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        # 5. Unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
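
For orientation, a small smoke-test sketch of the new model above. The tiny hyperparameters are arbitrary and chosen only to keep shapes consistent (inner_dim = 4 heads x 8 = 32; a 16x16 latent patchifies to an 8x8 grid), and the top-level `CogView4Transformer2DModel` export is assumed from this release's `diffusers/__init__.py` additions:

    import torch
    from diffusers import CogView4Transformer2DModel

    model = CogView4Transformer2DModel(
        patch_size=2, in_channels=4, out_channels=4, num_layers=2,
        attention_head_dim=8, num_attention_heads=4, text_embed_dim=32,
        time_embed_dim=16, condition_dim=4, rope_axes_dim=(8, 8),
    )
    latents = torch.randn(1, 4, 16, 16)    # (B, C, H, W) latent grid
    text = torch.randn(1, 8, 32)           # (B, seq, text_embed_dim)
    size = torch.tensor([[128.0, 128.0]])  # SDXL-style (height, width) conditions
    out = model(
        hidden_states=latents, encoder_hidden_states=text,
        timestep=torch.tensor([500]), original_size=size,
        target_size=size, crop_coords=torch.tensor([[0.0, 0.0]]),
    ).sample
    print(out.shape)  # torch.Size([1, 4, 16, 16])

Note how the forward signature mirrors CogView3's SDXL-style conditioning (original_size, target_size, crop_coords feed `CogView3CombinedTimestepSizeEmbeddings`), while LoRA and caching support come from the `PeftAdapterMixin` and `CacheMixin` bases.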