diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +13 -10
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
  475. diffusers-0.33.1.dist-info/RECORD +0 -608
  476. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  477. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_cogview3plus.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,18 +19,13 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

diffusers/models/transformers/transformer_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states = self.norm(hidden_states)
-        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+        dtype = hidden_states.dtype
+        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
 
         emb = self.linear(temb)
         (
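
The `.to(dtype=dtype)` casts added above (the same guard reappears in the QK-normalization hunk below) pin the norm outputs back to the input dtype: when a norm layer computes in float32 while the surrounding model runs in bfloat16 or float16, its output would otherwise leak the higher precision into the residual stream. A minimal sketch of the pattern, assuming a norm that deliberately computes in float32 (the class below is illustrative, not diffusers code):

    import torch
    import torch.nn as nn

    class FP32LayerNorm(nn.LayerNorm):
        # Illustrative stand-in: a norm that always computes in float32.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return super().forward(x.float())

    norm = FP32LayerNorm(16)
    x = torch.randn(2, 4, 16, dtype=torch.bfloat16)
    y_leaky = norm(x)              # float32: the upcast escapes the norm
    y_fixed = norm(x).to(x.dtype)  # bfloat16 again, as the new code enforces
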
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):
 
 class CogView4AttnProcessor:
     """
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
+
+    The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+    text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
     """
 
     def __init__(self):
@@ -125,8 +129,10 @@
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        dtype = encoder_hidden_states.dtype
+
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
         batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@
 
         # 2. QK normalization
         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            query = attn.norm_q(query).to(dtype=dtype)
         if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            key = attn.norm_k(key).to(dtype=dtype)
 
         # 3. Rotational positional embeddings applied to latent stream
         if image_rotary_emb is not None:
@@ -159,13 +165,14 @@
 
         # 4. Attention
         if attention_mask is not None:
-            text_attention_mask = attention_mask.float().to(query.device)
-            actual_text_seq_length = text_attention_mask.size(1)
-            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
-            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
-            new_attention_mask = new_attention_mask.unsqueeze(2)
-            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
-            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+            text_attn_mask = attention_mask
+            assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+            text_attn_mask = text_attn_mask.float().to(query.device)
+            mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            mix_attn_mask[:, :text_seq_length] = text_attn_mask
+            mix_attn_mask = mix_attn_mask.unsqueeze(2)
+            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
 
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
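
Two things happen in the rewritten block above: the combined mask is now initialized with `torch.ones` instead of `torch.zeros`, so image positions stay attendable rather than being silently masked out, and the pairwise mask is built as an outer product of a per-token keep vector, so position (i, j) survives only when both tokens are non-padded. A toy sketch of that outer-product construction (shapes and values are illustrative only):

    import torch

    batch_size, text_len, image_len = 1, 3, 2
    text_attn_mask = torch.tensor([[1.0, 1.0, 0.0]])  # last text token is padding

    keep = torch.ones(batch_size, text_len + image_len)  # image tokens always kept
    keep[:, :text_len] = text_attn_mask
    keep = keep.unsqueeze(2)                  # (B, S, 1)
    pairwise = keep @ keep.transpose(1, 2)    # (B, S, S): keep[i] * keep[j]
    attn_mask = (pairwise > 0).unsqueeze(1)   # (B, 1, S, S) for F.scaled_dot_product_attention
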
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
         return hidden_states, encoder_hidden_states
 
 
+class CogView4TrainingAttnProcessor:
+    """
+    Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+    embedding on query and key vectors, but does not include spatial normalization.
+
+    This processor differs from CogView4AttnProcessor in several important ways:
+    1. It supports attention masking with variable sequence lengths for multi-resolution training
+    2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+       provided
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        latent_attn_mask: Optional[torch.Tensor] = None,
+        text_attn_mask: Optional[torch.Tensor] = None,
+        batch_flag: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            attn (`Attention`):
+                The attention module.
+            hidden_states (`torch.Tensor`):
+                The input hidden states.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states for cross-attention.
+            latent_attn_mask (`torch.Tensor`, *optional*):
+                Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
+                attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
+                num_latent_tokens).
+            text_attn_mask (`torch.Tensor`, *optional*):
+                Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
+                is used for all text tokens.
+            batch_flag (`torch.Tensor`, *optional*):
+                Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
+                batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
+                batch1, and samples 3-4 form batch2. If None, no packing is used.
+            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+                The rotary embedding for the image part of the input.
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+        """
+
+        # Get dimensions and device info
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        dtype = encoder_hidden_states.dtype
+        device = encoder_hidden_states.device
+        latent_hidden_states = hidden_states
+        # Combine text and image streams for joint processing
+        mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+        # 1. Construct attention mask and maybe packing input
+        # Create default masks if not provided
+        if text_attn_mask is None:
+            text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+        if latent_attn_mask is None:
+            latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+        # Validate mask shapes and types
+        assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+        assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+        assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+        # Create combined mask for text and image tokens
+        mixed_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+        )
+        mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+        mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+        # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+        mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+        attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+        # Handle batch packing if enabled
+        if batch_flag is not None:
+            assert batch_flag.dim() == 1
+            # Determine packed batch size based on batch_flag
+            packing_batch_size = torch.max(batch_flag).item() + 1
+
+            # Calculate actual sequence lengths for each sample based on masks
+            text_seq_length = torch.sum(text_attn_mask, dim=1)
+            latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+            mixed_seq_length = text_seq_length + latent_seq_length
+
+            # Calculate packed sequence lengths for each packed batch
+            mixed_seq_length_packed = [
+                torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+            ]
+
+            assert len(mixed_seq_length_packed) == packing_batch_size
+
+            # Pack sequences by removing padding tokens
+            mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+            mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+            mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+            assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+            # Split the unpadded sequence into packed batches
+            mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+            # Re-pad to create packed batches with right-side padding
+            mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+                mixed_hidden_states_packed,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="right",
+            )
+
+            # Create attention mask for packed batches
+            l = mixed_hidden_states_packed_padded.shape[1]
+            attn_mask_matrix = torch.zeros(
+                (packing_batch_size, l, l),
+                dtype=dtype,
+                device=device,
+            )
+
+            # Fill attention mask with block diagonal matrices
+            # This ensures that tokens can only attend to other tokens within the same original sample
+            for idx, mask in enumerate(attn_mask_matrix):
+                seq_lengths = mixed_seq_length[batch_flag == idx]
+                offset = 0
+                for length in seq_lengths:
+                    # Create a block of 1s for each sample in the packed batch
+                    mask[offset : offset + length, offset : offset + length] = 1
+                    offset += length
+
+        attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+        attn_mask_matrix = attn_mask_matrix.unsqueeze(1)  # Add attention head dim
+        attention_mask = attn_mask_matrix
+
+        # Prepare hidden states for attention computation
+        if batch_flag is None:
+            # If no packing, just combine text and image tokens
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        else:
+            # If packing, use the packed sequence
+            hidden_states = mixed_hidden_states_packed_padded
+
+        # 2. QKV projections - convert hidden states to query, key, value
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 3. QK normalization - apply layer norm to queries and keys if configured
+        if attn.norm_q is not None:
+            query = attn.norm_q(query).to(dtype=dtype)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key).to(dtype=dtype)
+
+        # 4. Apply rotary positional embeddings to image tokens only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if batch_flag is None:
+                # Apply RoPE only to image tokens (after text tokens)
+                query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+                key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+            else:
+                # For packed batches, need to carefully apply RoPE to appropriate tokens
+                assert query.shape[0] == packing_batch_size
+                assert key.shape[0] == packing_batch_size
+                assert len(image_rotary_emb) == batch_size
+
+                rope_idx = 0
+                for idx in range(packing_batch_size):
+                    offset = 0
+                    # Get text and image sequence lengths for samples in this packed batch
+                    text_seq_length_bi = text_seq_length[batch_flag == idx]
+                    latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+                    # Apply RoPE to each image segment in the packed sequence
+                    for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+                        mlen = tlen + llen
+                        # Apply RoPE only to image tokens (after text tokens)
+                        query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            query[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            key[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        offset += mlen
+                        rope_idx += 1
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection - project attention output to model dimension
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Split the output back into text and image streams
+        if batch_flag is None:
+            # Simple split for non-packed case
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # For packed case: need to unpack, split text/image, then restore to original shapes
+            # First, unpad the sequence based on the packed sequence lengths
+            hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+                hidden_states,
+                lengths=torch.tensor(mixed_seq_length_packed),
+                batch_first=True,
+            )
+            # Concatenate all unpadded sequences
+            hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+            # Split by original sample sequence lengths
+            hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+            assert len(hidden_states_unpack) == batch_size
+
+            # Further split each sample's sequence into text and image parts
+            hidden_states_unpack = [
+                torch.split(h, [tlen, llen])
+                for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+            ]
+            # Separate text and image sequences
+            encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+            hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+            # Update the original tensors with the processed values, respecting the attention masks
+            for idx in range(batch_size):
+                # Place unpacked text tokens back in the encoder_hidden_states tensor
+                encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+                # Place unpacked image tokens back in the latent_hidden_states tensor
+                latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+            # Update the output hidden states
+            hidden_states = latent_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+
 class CogView4TransformerBlock(nn.Module):
     def __init__(
-        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
     ) -> None:
         super().__init__()
 
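The core of the new CogView4TrainingAttnProcessor above is batch_flag packing: variable-length samples that share a flag are depadded, concatenated into one row, and guarded by a block-diagonal mask so attention never crosses sample boundaries. A self-contained toy sketch of that mask construction under assumed lengths (not the processor itself):

    import torch

    batch_flag = torch.tensor([0, 1, 1])   # sample 0 alone; samples 1 and 2 packed together
    seq_lengths = torch.tensor([4, 2, 3])  # unpadded text+image length of each sample

    num_packed = int(batch_flag.max()) + 1
    packed_lengths = [int(seq_lengths[batch_flag == i].sum()) for i in range(num_packed)]  # [4, 5]
    max_len = max(packed_lengths)

    mask = torch.zeros(num_packed, max_len, max_len, dtype=torch.bool)
    for i in range(num_packed):
        offset = 0
        for length in seq_lengths[batch_flag == i]:
            # each sample may only attend within its own block
            mask[i, offset : offset + length, offset : offset + length] = True
            offset += int(length)
    mask = mask.unsqueeze(1)  # add the attention head dim, as the processor does
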
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -232,12 +508,14 @@
         ) = self.norm1(hidden_states, encoder_hidden_states, temb)
 
         # 2. Attention
+        if attention_kwargs is None:
+            attention_kwargs = {}
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
-            **kwargs,
+            **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -422,7 +702,8 @@
         batch_size, num_channels, height, width = hidden_states.shape
 
         # 1. RoPE
-        image_rotary_emb = self.rope(hidden_states)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.rope(hidden_states)
 
         # 2. Patch & Timestep embeddings
         p = self.config.patch_size
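
Because forward now computes RoPE only when image_rotary_emb is not supplied, a caller can precompute the embedding once and reuse it across denoising steps. A hedged usage sketch: transformer, latents, prompt_embeds, and timesteps are placeholder names, and any further conditioning arguments the model requires are elided:

    rope = transformer.rope(latents)  # the same call forward() would make internally
    for t in timesteps:
        noise_pred = transformer(
            hidden_states=latents,
            encoder_hidden_states=prompt_embeds,
            timestep=t,
            image_rotary_emb=rope,    # reused instead of recomputed every step
            return_dict=False,
        )[0]
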
@@ -438,11 +719,22 @@
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
 
         # 4. Output norm & projection
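
The switch from **kwargs to an explicit attention_kwargs argument in these two call sites lets the extra attention options travel through _gradient_checkpointing_func as a plain positional object; checkpoint wrappers re-invoke the block with positional arguments, and reentrant checkpointing cannot forward arbitrary keyword arguments. An illustrative sketch with torch.utils.checkpoint, not the diffusers helper itself:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    block = nn.Linear(8, 8)

    def run_block(hidden_states, attention_kwargs):
        # non-tensor extras ride along as one positional dict instead of **kwargs
        scale = attention_kwargs.get("scale", 1.0)
        return block(hidden_states) * scale

    x = torch.randn(2, 8, requires_grad=True)
    y = checkpoint(run_block, x, {"scale": 0.5}, use_reentrant=False)
    y.sum().backward()
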