diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -21,13 +21,14 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -73,8 +74,9 @@ class CogView4AdaLayerNormZero(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states = self.norm(hidden_states)
-        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+        dtype = hidden_states.dtype
+        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)

         emb = self.linear(temb)
         (
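
The new `.to(dtype=dtype)` casts in this hunk read as mixed-precision hygiene: normalization layers can return float32 even for low-precision inputs (for example under autocast, where `layer_norm` is a float32 op), so the result is cast back explicitly. A quick way to observe the effect; this assumes a CUDA device, and the motivation stated here is a reading of the change, not something the diff spells out:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(8, elementwise_affine=False).cuda()
x = torch.randn(1, 4, 8, device="cuda", dtype=torch.bfloat16)

with torch.autocast("cuda", dtype=torch.bfloat16):
    y = norm(x)

# float32: autocast upcasts layer_norm, hence the explicit .to(dtype) above
print(y.dtype)
```
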
@@ -111,8 +113,11 @@ class CogView4AdaLayerNormZero(nn.Module):

 class CogView4AttnProcessor:
     """
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
+
+    The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+    text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
     """

     def __init__(self):
@@ -125,8 +130,10 @@ class CogView4AttnProcessor:
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        dtype = encoder_hidden_states.dtype
+
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
         batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +149,9 @@ class CogView4AttnProcessor:

         # 2. QK normalization
         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            query = attn.norm_q(query).to(dtype=dtype)
         if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            key = attn.norm_k(key).to(dtype=dtype)

         # 3. Rotational positional embeddings applied to latent stream
         if image_rotary_emb is not None:
@@ -159,13 +166,14 @@

         # 4. Attention
         if attention_mask is not None:
-            text_attention_mask = attention_mask.float().to(query.device)
-            actual_text_seq_length = text_attention_mask.size(1)
-            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
-            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
-            new_attention_mask = new_attention_mask.unsqueeze(2)
-            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
-            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+            text_attn_mask = attention_mask
+            assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+            text_attn_mask = text_attn_mask.float().to(query.device)
+            mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            mix_attn_mask[:, :text_seq_length] = text_attn_mask
+            mix_attn_mask = mix_attn_mask.unsqueeze(2)
+            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)

         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
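
The rewritten mask path above derives a joint text+image mask from the 2D text padding mask via an outer product, with image positions always unmasked. A minimal standalone sketch of that computation (toy shapes; the variable names are local to this example, not part of the diff):

```python
import torch

# One sample, 3 text tokens (the last is padding), 2 image tokens.
batch_size, text_len, image_len = 1, 3, 2
text_attn_mask = torch.tensor([[1.0, 1.0, 0.0]])  # (batch, text_seq_length)

mix = torch.ones((batch_size, text_len + image_len))
mix[:, :text_len] = text_attn_mask        # image positions stay 1 (never padded)
mix = mix.unsqueeze(2)                    # (batch, seq, 1)
matrix = mix @ mix.transpose(1, 2)        # outer product -> (batch, seq, seq)
attn_bias = (matrix > 0).unsqueeze(1)     # (batch, 1, seq, seq), ready for SDPA

# Row/column 2 (the padded text token) is all False: it neither attends nor is
# attended to; every other query/key pair is allowed.
print(attn_bias[0, 0].int())
```

Note the switch from `torch.zeros` to `torch.ones` for the joint mask: under the old initialization, image positions were never set to 1, so the outer product masked image tokens entirely whenever an attention mask was passed. The new code appears to fix exactly that.
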
@@ -183,9 +191,277 @@ class CogView4AttnProcessor:
         return hidden_states, encoder_hidden_states


+class CogView4TrainingAttnProcessor:
+    """
+    Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+    embedding on query and key vectors, but does not include spatial normalization.
+
+    This processor differs from CogView4AttnProcessor in several important ways:
+    1. It supports attention masking with variable sequence lengths for multi-resolution training
+    2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+       provided
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        latent_attn_mask: Optional[torch.Tensor] = None,
+        text_attn_mask: Optional[torch.Tensor] = None,
+        batch_flag: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            attn (`Attention`):
+                The attention module.
+            hidden_states (`torch.Tensor`):
+                The input hidden states.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states for cross-attention.
+            latent_attn_mask (`torch.Tensor`, *optional*):
+                Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
+                attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
+                num_latent_tokens).
+            text_attn_mask (`torch.Tensor`, *optional*):
+                Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
+                is used for all text tokens.
+            batch_flag (`torch.Tensor`, *optional*):
+                Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
+                batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
+                batch1, and samples 3-4 form batch2. If None, no packing is used.
+            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+                The rotary embedding for the image part of the input.
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+        """
+
+        # Get dimensions and device info
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        dtype = encoder_hidden_states.dtype
+        device = encoder_hidden_states.device
+        latent_hidden_states = hidden_states
+        # Combine text and image streams for joint processing
+        mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+        # 1. Construct attention mask and maybe packing input
+        # Create default masks if not provided
+        if text_attn_mask is None:
+            text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+        if latent_attn_mask is None:
+            latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+        # Validate mask shapes and types
+        assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+        assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+        assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+        # Create combined mask for text and image tokens
+        mixed_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+        )
+        mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+        mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+        # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+        mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+        attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+        # Handle batch packing if enabled
+        if batch_flag is not None:
+            assert batch_flag.dim() == 1
+            # Determine packed batch size based on batch_flag
+            packing_batch_size = torch.max(batch_flag).item() + 1
+
+            # Calculate actual sequence lengths for each sample based on masks
+            text_seq_length = torch.sum(text_attn_mask, dim=1)
+            latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+            mixed_seq_length = text_seq_length + latent_seq_length
+
+            # Calculate packed sequence lengths for each packed batch
+            mixed_seq_length_packed = [
+                torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+            ]
+
+            assert len(mixed_seq_length_packed) == packing_batch_size
+
+            # Pack sequences by removing padding tokens
+            mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+            mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+            mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+            assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+            # Split the unpadded sequence into packed batches
+            mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+            # Re-pad to create packed batches with right-side padding
+            mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+                mixed_hidden_states_packed,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="right",
+            )
+
+            # Create attention mask for packed batches
+            l = mixed_hidden_states_packed_padded.shape[1]
+            attn_mask_matrix = torch.zeros(
+                (packing_batch_size, l, l),
+                dtype=dtype,
+                device=device,
+            )
+
+            # Fill attention mask with block diagonal matrices
+            # This ensures that tokens can only attend to other tokens within the same original sample
+            for idx, mask in enumerate(attn_mask_matrix):
+                seq_lengths = mixed_seq_length[batch_flag == idx]
+                offset = 0
+                for length in seq_lengths:
+                    # Create a block of 1s for each sample in the packed batch
+                    mask[offset : offset + length, offset : offset + length] = 1
+                    offset += length

+        attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+        attn_mask_matrix = attn_mask_matrix.unsqueeze(1)  # Add attention head dim
+        attention_mask = attn_mask_matrix
+
+        # Prepare hidden states for attention computation
+        if batch_flag is None:
+            # If no packing, just combine text and image tokens
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        else:
+            # If packing, use the packed sequence
+            hidden_states = mixed_hidden_states_packed_padded
+
+        # 2. QKV projections - convert hidden states to query, key, value
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 3. QK normalization - apply layer norm to queries and keys if configured
+        if attn.norm_q is not None:
+            query = attn.norm_q(query).to(dtype=dtype)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key).to(dtype=dtype)
+
+        # 4. Apply rotary positional embeddings to image tokens only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if batch_flag is None:
+                # Apply RoPE only to image tokens (after text tokens)
+                query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+                key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+            else:
+                # For packed batches, need to carefully apply RoPE to appropriate tokens
+                assert query.shape[0] == packing_batch_size
+                assert key.shape[0] == packing_batch_size
+                assert len(image_rotary_emb) == batch_size
+
+                rope_idx = 0
+                for idx in range(packing_batch_size):
+                    offset = 0
+                    # Get text and image sequence lengths for samples in this packed batch
+                    text_seq_length_bi = text_seq_length[batch_flag == idx]
+                    latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+                    # Apply RoPE to each image segment in the packed sequence
+                    for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+                        mlen = tlen + llen
+                        # Apply RoPE only to image tokens (after text tokens)
+                        query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            query[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            key[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        offset += mlen
+                        rope_idx += 1
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection - project attention output to model dimension
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Split the output back into text and image streams
+        if batch_flag is None:
+            # Simple split for non-packed case
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # For packed case: need to unpack, split text/image, then restore to original shapes
+            # First, unpad the sequence based on the packed sequence lengths
+            hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+                hidden_states,
+                lengths=torch.tensor(mixed_seq_length_packed),
+                batch_first=True,
+            )
+            # Concatenate all unpadded sequences
+            hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+            # Split by original sample sequence lengths
+            hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+            assert len(hidden_states_unpack) == batch_size
+
+            # Further split each sample's sequence into text and image parts
+            hidden_states_unpack = [
+                torch.split(h, [tlen, llen])
+                for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+            ]
+            # Separate text and image sequences
+            encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+            hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+            # Update the original tensors with the processed values, respecting the attention masks
+            for idx in range(batch_size):
+                # Place unpacked text tokens back in the encoder_hidden_states tensor
+                encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+                # Place unpacked image tokens back in the latent_hidden_states tensor
+                latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+            # Update the output hidden states
+            hidden_states = latent_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
 class CogView4TransformerBlock(nn.Module):
     def __init__(
-        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
     ) -> None:
         super().__init__()

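The packing path in the new processor is easiest to see with the docstring's own `batch_flag` example. A small self-contained sketch of the block-diagonal mask it builds (hypothetical lengths; this mirrors the fill loop above but is not part of the diff):

```python
import torch

# batch_flag [0, 1, 1]: sample 0 is packed alone, samples 1 and 2 share a row.
batch_flag = torch.tensor([0, 1, 1])
mixed_seq_length = torch.tensor([4, 2, 3])  # unpadded text+image length per sample

packing_batch_size = int(batch_flag.max()) + 1
packed_len = max(int(mixed_seq_length[batch_flag == b].sum()) for b in range(packing_batch_size))

# Tokens may only attend within their original sample: one diagonal block each.
mask = torch.zeros((packing_batch_size, packed_len, packed_len), dtype=torch.bool)
for b in range(packing_batch_size):
    offset = 0
    for length in mixed_seq_length[batch_flag == b].tolist():
        mask[b, offset : offset + length, offset : offset + length] = True
        offset += length

print(mask[1].int())  # a 2x2 block and a 3x3 block on the diagonal, zeros elsewhere
```

Right-side padding rows (sample 0 occupies only 4 of the 5 packed positions) stay all False, matching the zero-initialized mask in the processor.
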
@@ -213,9 +489,11 @@ class CogView4TransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -232,12 +510,14 @@ class CogView4TransformerBlock(nn.Module):
         ) = self.norm1(hidden_states, encoder_hidden_states, temb)

         # 2. Attention
+        if attention_kwargs is None:
+            attention_kwargs = {}
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
-            **kwargs,
+            **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -304,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
         return (freqs.cos(), freqs.sin())


+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
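
The whole point of the new class is the conditioning path: the stock `AdaLayerNormContinuous` in `..normalization` (as I recall it; verify against the diffusers source) runs the conditioning embedding through a SiLU before its linear, while CogView4's Megatron-trained weights expect the linear to see the raw embedding. A toy comparison of the two paths (hypothetical dimensions; only the activation differs, the downstream scale/shift affine is identical):

```python
import torch
import torch.nn as nn

cond = torch.randn(2, 512)            # conditioning embedding (temb)
linear = nn.Linear(512, 2 * 2560)     # produces concatenated scale/shift

emb_stock = linear(nn.functional.silu(cond))  # stock AdaLayerNormContinuous path
emb_cogview4 = linear(cond)                   # CogView4 variant: no activation

scale, shift = emb_cogview4.chunk(2, dim=1)
x = torch.randn(2, 16, 2560)          # (batch, seq, dim)
out = nn.LayerNorm(2560, elementwise_affine=False)(x) * (1 + scale[:, None, :]) + shift[:, None, :]
print(out.shape)                      # torch.Size([2, 16, 2560])
```
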
@@ -386,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         )

         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

         self.gradient_checkpointing = False
@@ -402,7 +714,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -422,7 +736,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         batch_size, num_channels, height, width = hidden_states.shape

         # 1. RoPE
-        image_rotary_emb = self.rope(hidden_states)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.rope(hidden_states)

         # 2. Patch & Timestep embeddings
         p = self.config.patch_size
@@ -438,11 +753,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )

         # 4. Output norm & projection
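
Taken together, the block and model forwards now thread an explicit `attention_kwargs` dict into every `attn1` call instead of bare `**kwargs`, which is what lets the training processor's extra arguments travel through the model. A hedged wiring sketch (the repo id is illustrative; `set_processor` is the standard diffusers `Attention` API):

```python
import torch
from diffusers import CogView4Transformer2DModel
from diffusers.models.transformers.transformer_cogview4 import CogView4TrainingAttnProcessor

# Load the transformer and swap every block's attention processor for training.
transformer = CogView4Transformer2DModel.from_pretrained(
    "THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16
)
for block in transformer.transformer_blocks:
    block.attn1.set_processor(CogView4TrainingAttnProcessor())

# Per-sample masks and packing flags then ride in `attention_kwargs`, which the
# model forward copies and hands to each block (see the last hunk above):
#   transformer(
#       ...,  # usual inputs: latents, text embeddings, timestep, ...
#       attention_kwargs={
#           "text_attn_mask": text_mask,      # (batch, text_seq_length), int32
#           "latent_attn_mask": latent_mask,  # (batch, num_latent_tokens), int32
#           "batch_flag": batch_flag,         # optional packing assignment
#       },
#   )
```
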