diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,789 @@
1
+ # Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
24
+ from ...utils.torch_utils import maybe_allow_in_graph
25
+ from ..attention import Attention, FeedForward
26
+ from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
27
+ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
28
+ from ..modeling_outputs import Transformer2DModelOutput
29
+ from ..modeling_utils import ModelMixin
30
+ from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
class PerceiverAttention(nn.Module):
    """Multi-head attention in which latent queries attend over ``[image_embeds; latents]``.

    Queries come from the (normalized) latents only. Keys and values come from the normalized
    image embeddings concatenated with the normalized latents along the sequence axis.

    Args:
        dim: Feature size of ``latents`` and of the returned tensor.
        dim_head: Feature size of each attention head.
        heads: Number of attention heads.
        kv_dim: Feature size used for ``norm1`` / ``to_kv``; falls back to ``dim`` when ``None``.
            NOTE(review): ``forward`` concatenates ``image_embeds`` with ``latents`` on the sequence
            axis before ``to_kv``, so a ``kv_dim`` different from ``dim`` cannot work as written.
    """

    def __init__(self, dim: int, dim_head: int = 64, heads: int = 8, kv_dim: Optional[int] = None):
        super().__init__()

        context_dim = dim if kv_dim is None else kv_dim
        inner_dim = heads * dim_head

        # `scale` is kept as a public attribute; `forward` applies its own split scaling instead.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads

        # Submodules are created in the same order as before so parameter init consumes the RNG identically.
        self.norm1 = nn.LayerNorm(context_dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, image_embeds: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """Attend from ``latents`` to the concatenation of ``image_embeds`` and ``latents``.

        Args:
            image_embeds: Context tokens, shaped ``(batch, ctx_len, features)``.
            latents: Query tokens, shaped ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        context = self.norm1(image_embeds)
        queries_in = self.norm2(latents)
        bsz, n_queries, _ = queries_in.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, tokens, heads * dim_head) -> (batch, heads, tokens, dim_head)
            return t.reshape(t.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

        q = split_heads(self.to_q(queries_in))
        kv_source = torch.cat((context, queries_in), dim=-2)
        k, v = [split_heads(part) for part in self.to_kv(kv_source).chunk(2, dim=-1)]

        # Apply 1/sqrt(dim_head) as two factors of dim_head**-0.25 on q and k; more stable in f16
        # than dividing the product afterwards.
        qk_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (q * qk_scale) @ (k * qk_scale).transpose(-2, -1)
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
        attended = probs @ v

        merged = attended.permute(0, 2, 1, 3).reshape(bsz, n_queries, -1)
        return self.to_out(merged)
79
+
80
+
81
class LocalFacialExtractor(nn.Module):
    """Perceiver-style resampler that fuses an identity embedding with multi-scale ViT features.

    ``num_queries`` learnable latent queries are extended with ``num_id_token`` tokens derived from the
    identity embedding. The combined latents are refined by ``depth`` (PerceiverAttention + feed-forward)
    layer pairs, split evenly across the ``num_scale`` ViT feature inputs. Only the query latents are
    kept, and they are projected to ``output_dim``.

    Args:
        id_dim: Feature size of the identity embedding.
        vit_dim: Feature size of the ViT hidden states and of the internal latents.
        depth: Total number of (attention, feed-forward) layer pairs; must be divisible by ``num_scale``.
        dim_head: Per-head size for ``PerceiverAttention``.
        heads: Number of heads for ``PerceiverAttention``.
        num_id_token: Number of tokens the identity embedding is expanded into.
        num_queries: Number of learnable latent queries (and of output tokens).
        output_dim: Feature size of the returned tokens.
        ff_mult: Hidden-size multiplier of the feed-forward layers.
        num_scale: Number of ViT feature inputs consumed by ``forward``.

    Raises:
        ValueError: If ``depth`` is not divisible by ``num_scale``.
    """

    def __init__(
        self,
        id_dim: int = 1280,
        vit_dim: int = 1024,
        depth: int = 10,
        dim_head: int = 64,
        heads: int = 16,
        num_id_token: int = 5,
        num_queries: int = 32,
        output_dim: int = 2048,
        ff_mult: int = 4,
        num_scale: int = 5,
    ):
        super().__init__()

        # The layer stack is split evenly across the `num_scale` ViT inputs. Raise instead of `assert`
        # so the check is not stripped when Python runs with `-O`.
        if depth % num_scale != 0:
            raise ValueError(f"`depth` ({depth}) must be divisible by `num_scale` ({num_scale}).")

        # Storing identity token and query information
        self.num_id_token = num_id_token
        self.vit_dim = vit_dim
        self.num_queries = num_queries
        # NOTE: `self.depth` is the number of layer pairs *per scale*, not the total `depth` argument.
        self.depth = depth // num_scale
        self.num_scale = num_scale
        scale = vit_dim**-0.5

        # Learnable latent query embeddings
        self.latents = nn.Parameter(torch.randn(1, num_queries, vit_dim) * scale)
        # Projection matrix mapping the retained query latents to `output_dim`
        self.proj_out = nn.Parameter(scale * torch.randn(vit_dim, output_dim))

        # `depth` (PerceiverAttention, feed-forward) pairs. `forward` assigns consecutive groups of
        # `self.depth` pairs to successive scales. Creation order is unchanged so parameter init
        # consumes the RNG identically.
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=vit_dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
                        nn.Sequential(
                            nn.LayerNorm(vit_dim),
                            nn.Linear(vit_dim, vit_dim * ff_mult, bias=False),
                            nn.GELU(),
                            nn.Linear(vit_dim * ff_mult, vit_dim, bias=False),
                        ),  # ConsisIDFeedForward layer
                    ]
                )
            )

        # One mapping network per ViT feature scale, registered as `mapping_0` ... `mapping_{num_scale - 1}`
        # (these attribute names become the state-dict key prefixes).
        for i in range(num_scale):
            setattr(
                self,
                f"mapping_{i}",
                nn.Sequential(
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                    nn.LayerNorm(vit_dim),
                    nn.LeakyReLU(),
                    nn.Linear(vit_dim, vit_dim),
                ),
            )

        # Maps the identity embedding to `num_id_token` tokens of size `vit_dim`
        self.id_embedding_mapping = nn.Sequential(
            nn.Linear(id_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim),
            nn.LayerNorm(vit_dim),
            nn.LeakyReLU(),
            nn.Linear(vit_dim, vit_dim * num_id_token),
        )

    def forward(self, id_embeds: torch.Tensor, vit_hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Fuse the identity embedding with the ViT features into ``num_queries`` output tokens.

        Args:
            id_embeds: Identity embedding whose last dimension is ``id_dim``. Presumably ``(batch, id_dim)``;
                TODO confirm against callers.
            vit_hidden_states: Sequence of at least ``num_scale`` tensors. Each is shaped
                ``(batch, tokens, vit_dim)``, since it is concatenated on dim 1 with the identity tokens.

        Returns:
            Tensor of shape ``(batch, num_queries, output_dim)``.
        """
        # Repeat latent queries for the batch size
        latents = self.latents.repeat(id_embeds.size(0), 1, 1)

        # Map the identity embedding to `num_id_token` tokens
        id_embeds = self.id_embedding_mapping(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_id_token, self.vit_dim)

        # Concatenate identity tokens with the latent queries
        latents = torch.cat((latents, id_embeds), dim=1)

        # Process each of the `num_scale` visual feature inputs with its own slice of the layer stack
        for i in range(self.num_scale):
            vit_feature = getattr(self, f"mapping_{i}")(vit_hidden_states[i])
            ctx_feature = torch.cat((id_embeds, vit_feature), dim=1)

            # Pass through the PerceiverAttention and feed-forward layers (residual connections)
            for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                latents = attn(ctx_feature, latents) + latents
                latents = ff(latents) + latents

        # Retain only the query latents (drop the appended identity tokens)
        latents = latents[:, : self.num_queries]
        # Project the latents to the output dimension
        latents = latents @ self.proj_out
        return latents
181
+
182
+
183
class PerceiverCrossAttention(nn.Module):
    r"""
    Multi-head cross-attention in which `hidden_states` provide the queries and `image_embeds` provide the keys and
    values. Both inputs are layer-normalized first; the result is projected back to width `dim`.

    Parameters:
        dim (`int`, defaults to `3072`):
            Last dimension of `hidden_states` and of the output.
        dim_head (`int`, defaults to `128`):
            Channels per attention head.
        heads (`int`, defaults to `16`):
            Number of attention heads.
        kv_dim (`int`, *optional*, defaults to `2048`):
            Last dimension of `image_embeds`. When `None`, `dim` is used instead.
    """

    def __init__(self, dim: int = 3072, dim_head: int = 128, heads: int = 16, kv_dim: int = 2048):
        super().__init__()

        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        projected_width = dim_head * heads
        context_width = dim if kv_dim is None else kv_dim

        # Pre-normalization of the key/value source (norm1) and of the query source (norm2)
        self.norm1 = nn.LayerNorm(context_width)
        self.norm2 = nn.LayerNorm(dim)

        # Projections; created in this order so seeded initialization is reproducible
        self.to_q = nn.Linear(dim, projected_width, bias=False)
        self.to_kv = nn.Linear(context_width, projected_width * 2, bias=False)
        self.to_out = nn.Linear(projected_width, dim, bias=False)

    def forward(self, image_embeds: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            image_embeds (`torch.Tensor`): key/value source, last dimension `kv_dim` (or `dim` if `kv_dim` is None).
            hidden_states (`torch.Tensor`): query source, unpacked as `(batch, seq_len, dim)`.

        Returns:
            `torch.Tensor` of shape `(batch, seq_len, dim)`.
        """
        context = self.norm1(image_embeds)
        query_source = self.norm2(hidden_states)
        batch, length, _ = query_source.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, heads * dim_head) -> (B, heads, L, dim_head)
            return t.reshape(t.size(0), -1, self.heads, self.dim_head).transpose(1, 2)

        keys, values = self.to_kv(context).chunk(2, dim=-1)
        queries, keys, values = (split_heads(t) for t in (self.to_q(query_source), keys, values))

        # The 1/sqrt(dim_head) factor is split evenly between queries and keys (dim_head ** -0.25 each) so the
        # logits stay smaller in low precision than dividing after the matmul would.
        qk_factor = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * qk_factor) @ (keys * qk_factor).transpose(-2, -1)
        probs = logits.float().softmax(dim=-1).type(logits.dtype)

        # (B, heads, L, dim_head) -> (B, L, heads * dim_head)
        mixed = (probs @ values).permute(0, 2, 1, 3).reshape(batch, length, -1)
        return self.to_out(mixed)
229
+
230
+
231
@maybe_allow_in_graph
class ConsisIDBlock(nn.Module):
    r"""
    Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.

    The block applies joint self-attention over video and text tokens, then a feed-forward layer over the concatenated
    text + video sequence. Before each sub-layer a `CogVideoXLayerNormZero` normalizes both streams conditioned on the
    timestep embedding. The gates it returns scale each sub-layer's output before the residual add.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to apply layer normalization after the query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon of the two `CogVideoXLayerNormZero` layers (the `Attention` layer is built with a fixed
            `eps=1e-6`).
        final_dropout (`bool`, defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            hidden_states (`torch.Tensor`): video tokens.
            encoder_hidden_states (`torch.Tensor`): text tokens; `size(1)` is taken as the text sequence length.
            temb (`torch.Tensor`): timestep embedding that conditions both normalization layers.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): forwarded to the attention layer.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the updated `(hidden_states, encoder_hidden_states)`.
        """
        text_seq_length = encoder_hidden_states.size(1)

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # gated residual connections for both streams
        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward over [text, video]; the output is split back using the text sequence length
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
349
+
350
+
351
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, defaults to `0`):
            Frequency shift passed to the sinusoidal timestep projection (`Timesteps`).
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
        use_rotary_positional_embeddings (`bool`, defaults to `False`):
            When `False`, the patch embedding adds its own positional embeddings. When `True` it does not, and rotary
            embeddings are expected to be supplied through `image_rotary_emb` in `forward`.
        use_learned_positional_embeddings (`bool`, defaults to `False`):
            Whether the patch embedding uses learned positional embeddings. Only accepted together with
            `use_rotary_positional_embeddings=True`; otherwise a `ValueError` is raised.
        is_train_face (`bool`, defaults to `False`):
            Whether to enable the identity-preserving modules (`local_facial_extractor` and
            `perceiver_cross_attention`). When `True`, `id_cond` and `id_vit_hidden` must be passed to `forward`.
        is_kps (`bool`, defaults to `False`):
            Whether to enable keypoint for global facial extractor. If `True`, keypoints will be in the model.
        cross_attn_interval (`int`, defaults to `2`):
            A perceiver cross-attention layer is applied after transformer blocks `0, interval, 2 * interval, ...`.
            A larger value reduces the number of cross-attention computations.
        cross_attn_dim_head (`int`, optional, defaults to `128`):
            The dimensionality of each attention head in the cross-attention layers of the Transformer architecture. A
            larger value increases the capacity to attend to more complex patterns, but also increases memory and
            computation costs.
        cross_attn_num_heads (`int`, optional, defaults to `16`):
            The number of attention heads in the cross-attention layers. More heads allow for more parallel attention
            mechanisms, capturing diverse relationships between different components of the input, but can also
            increase computational requirements.
        LFE_id_dim (`int`, optional, defaults to `1280`):
            The dimensionality of the identity vector used in the Local Facial Extractor (LFE). This vector represents
            the identity features of a face, which are important for tasks like face recognition and identity
            preservation across different frames.
        LFE_vit_dim (`int`, optional, defaults to `1024`):
            The dimension of the vision transformer (ViT) output used in the Local Facial Extractor (LFE). This value
            dictates the size of the transformer-generated feature vectors that will be processed for facial feature
            extraction.
        LFE_depth (`int`, optional, defaults to `10`):
            The number of layers in the Local Facial Extractor (LFE). Increasing the depth allows the model to capture
            more complex representations of facial features, but also increases the computational load.
        LFE_dim_head (`int`, optional, defaults to `64`):
            The dimensionality of each attention head in the Local Facial Extractor (LFE). This parameter affects how
            finely the model can process and focus on different parts of the facial features during the extraction
            process.
        LFE_num_heads (`int`, optional, defaults to `16`):
            The number of attention heads in the Local Facial Extractor (LFE). More heads can improve the model's
            ability to capture diverse facial features, but at the cost of increased computational complexity.
        LFE_num_id_token (`int`, optional, defaults to `5`):
            The number of identity tokens used in the Local Facial Extractor (LFE). This defines how many
            identity-related tokens the model will process to ensure face identity preservation during feature
            extraction.
        LFE_num_querie (`int`, optional, defaults to `32`):
            The number of query tokens used in the Local Facial Extractor (LFE). These tokens are used to capture
            high-frequency face-related information that aids in accurate facial feature extraction.
        LFE_output_dim (`int`, optional, defaults to `2048`):
            The output dimension of the Local Facial Extractor (LFE). It is also the key/value dimension of the
            perceiver cross-attention layers, which attend to the LFE output.
        LFE_ff_mult (`int`, optional, defaults to `4`):
            The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
            Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
            transformations, but also increases the computation and memory requirements.
        LFE_num_scale (`int`, optional, defaults to `5`):
            The number of different scales visual feature. A higher value increases the model's capacity to learn more
            complex facial feature transformations, but also increases the computation and memory requirements.
        local_face_scale (`float`, defaults to `1.0`):
            Multiplier applied to the perceiver cross-attention output before it is added to the video tokens.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        is_train_face: bool = False,
        is_kps: bool = False,
        cross_attn_interval: int = 2,
        cross_attn_dim_head: int = 128,
        cross_attn_num_heads: int = 16,
        LFE_id_dim: int = 1280,
        LFE_vit_dim: int = 1024,
        LFE_depth: int = 10,
        LFE_dim_head: int = 64,
        LFE_num_heads: int = 16,
        LFE_num_id_token: int = 5,
        LFE_num_querie: int = 32,
        LFE_output_dim: int = 2048,
        LFE_ff_mult: int = 4,
        LFE_num_scale: int = 5,
        local_face_scale: float = 1.0,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
                "issue at https://github.com/huggingface/diffusers/issues."
            )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            text_embed_dim=text_embed_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            max_text_seq_length=max_text_seq_length,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 3. Define spatio-temporal transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                ConsisIDBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.is_train_face = is_train_face
        self.is_kps = is_kps

        # 5. Define identity-preserving config
        if is_train_face:
            # LFE configs
            self.LFE_id_dim = LFE_id_dim
            self.LFE_vit_dim = LFE_vit_dim
            self.LFE_depth = LFE_depth
            self.LFE_dim_head = LFE_dim_head
            self.LFE_num_heads = LFE_num_heads
            self.LFE_num_id_token = LFE_num_id_token
            self.LFE_num_querie = LFE_num_querie
            self.LFE_output_dim = LFE_output_dim
            self.LFE_ff_mult = LFE_ff_mult
            self.LFE_num_scale = LFE_num_scale
            # cross configs
            self.inner_dim = inner_dim
            self.cross_attn_interval = cross_attn_interval
            # `forward` applies cross-attention after every block index divisible by `cross_attn_interval`, i.e.
            # ceil(num_layers / cross_attn_interval) times; floor division would run out of layers whenever
            # `num_layers` is not a multiple of the interval.
            self.num_cross_attn = (num_layers + cross_attn_interval - 1) // cross_attn_interval
            self.cross_attn_dim_head = cross_attn_dim_head
            self.cross_attn_num_heads = cross_attn_num_heads
            # The cross-attention keys/values are the LFE output tokens, so their width is `LFE_output_dim`.
            # Deriving it from `inner_dim` (as `inner_dim * 2 / 3`) only matches when `inner_dim == 3072`.
            self.cross_attn_kv_dim = LFE_output_dim
            self.local_face_scale = local_face_scale
            # face modules
            self._init_face_inputs()

        self.gradient_checkpointing = False

    def _init_face_inputs(self):
        # Build the Local Facial Extractor and one perceiver cross-attention layer per cross-attention slot.
        self.local_facial_extractor = LocalFacialExtractor(
            id_dim=self.LFE_id_dim,
            vit_dim=self.LFE_vit_dim,
            depth=self.LFE_depth,
            dim_head=self.LFE_dim_head,
            heads=self.LFE_num_heads,
            num_id_token=self.LFE_num_id_token,
            num_queries=self.LFE_num_querie,
            output_dim=self.LFE_output_dim,
            ff_mult=self.LFE_ff_mult,
            num_scale=self.LFE_num_scale,
        )
        self.perceiver_cross_attention = nn.ModuleList(
            [
                PerceiverCrossAttention(
                    dim=self.inner_dim,
                    dim_head=self.cross_attn_dim_head,
                    heads=self.cross_attn_num_heads,
                    kv_dim=self.cross_attn_kv_dim,
                )
                for _ in range(self.num_cross_attn)
            ]
        )

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        id_cond: Optional[torch.Tensor] = None,
        id_vit_hidden: Optional[List[torch.Tensor]] = None,
        return_dict: bool = True,
    ):
        r"""
        Run the ConsisID transformer.

        Args:
            hidden_states (`torch.Tensor`):
                Video latents, unpacked as `(batch_size, num_frames, channels, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings; `shape[1]` is used as the text sequence length.
            timestep (`int`, `float` or `torch.LongTensor`):
                Timestep(s), projected by `time_proj` and embedded by `time_embedding`.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning forwarded to `time_embedding`.
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Rotary embeddings forwarded to every transformer block.
            attention_kwargs (`dict`, *optional*):
                Only the `"scale"` entry is read, as the LoRA scale (effective with the PEFT backend only). The
                caller's dict is not modified.
            id_cond (`torch.Tensor`, *optional*):
                Identity embedding for the Local Facial Extractor. Required when `is_train_face` is enabled.
            id_vit_hidden (`List[torch.Tensor]`, *optional*):
                Multi-scale ViT features for the Local Facial Extractor. Required when `is_train_face` is enabled.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` whose `sample` has shape `(batch_size, num_frames, out_channels, height,
            width)`, or the tuple `(sample,)` when `return_dict=False`.

        Raises:
            ValueError: If `is_train_face` is enabled and `id_cond` or `id_vit_hidden` is missing.
        """
        lora_scale = 1.0
        scale_was_passed = False
        if attention_kwargs is not None:
            # Copy so popping "scale" does not mutate the caller's dict.
            attention_kwargs = attention_kwargs.copy()
            # Record this before the pop; afterwards the key is gone and the warning below could never fire.
            scale_was_passed = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        # fuse clip and insightface
        valid_face_emb = None
        if self.is_train_face:
            if id_cond is None or id_vit_hidden is None:
                raise ValueError("`id_cond` and `id_vit_hidden` must be provided when `is_train_face` is enabled.")
            id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
            id_vit_hidden = [
                tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
            ]
            # Example shapes: [1, 1280] and 5 x [1, 577, 1024] -> [1, 32, 2048]
            valid_face_emb = self.local_facial_extractor(id_cond, id_vit_hidden)

        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding: text tokens come first, followed by the video patch tokens.
        # Example shapes: text [1, 226, 4096], video [1, 13, 32, 60, 90] -> [1, 17776, 3072]
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]  # e.g. [1, 226, 3072]
        hidden_states = hidden_states[:, text_seq_length:]  # e.g. [1, 17550, 3072]

        # 3. Transformer blocks
        ca_idx = 0
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            # Inject local facial features after every `cross_attn_interval`-th block.
            if self.is_train_face:
                if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
                    hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
                        valid_face_emb, hidden_states
                    )  # e.g. [2, 32, 2048] attended by [2, 17550, 3072]
                    ca_idx += 1

        # Final norm is applied over the joint [text, video] sequence, then only the video tokens are kept.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = self.norm_final(hidden_states)
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        # Note: `-1` is used for the channel axis instead of `channels`. Using `channels` would also be fine for
        # ConsisID only because its number of input channels equals its number of output channels.
        p = self.config.patch_size
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)