diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,10 @@ import torch.nn.functional as F
21
21
 
22
22
  from ...configuration_utils import ConfigMixin, register_to_config
23
23
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
- from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25
- from ..attention import FeedForward
26
- from ..attention_processor import Attention
24
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils.torch_utils import maybe_allow_in_graph
26
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
27
+ from ..attention_dispatch import dispatch_attention_fn
27
28
  from ..cache_utils import CacheMixin
28
29
  from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
29
30
  from ..modeling_outputs import Transformer2DModelOutput
@@ -34,69 +35,117 @@ from ..normalization import FP32LayerNorm
34
35
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
36
 
36
37
 
37
- class WanAttnProcessor2_0:
38
+ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
39
+ # encoder_hidden_states is only passed for cross-attention
40
+ if encoder_hidden_states is None:
41
+ encoder_hidden_states = hidden_states
42
+
43
+ if attn.fused_projections:
44
+ if attn.cross_attention_dim_head is None:
45
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
46
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
47
+ else:
48
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
49
+ query = attn.to_q(hidden_states)
50
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
51
+ else:
52
+ query = attn.to_q(hidden_states)
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+ return query, key, value
56
+
57
+
58
+ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
59
+ if attn.fused_projections:
60
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
61
+ else:
62
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
63
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
64
+ return key_img, value_img
65
+
66
+
67
+ class WanAttnProcessor:
68
+ _attention_backend = None
69
+
38
70
  def __init__(self):
39
71
  if not hasattr(F, "scaled_dot_product_attention"):
40
- raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
72
+ raise ImportError(
73
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
74
+ )
41
75
 
42
76
  def __call__(
43
77
  self,
44
- attn: Attention,
78
+ attn: "WanAttention",
45
79
  hidden_states: torch.Tensor,
46
80
  encoder_hidden_states: Optional[torch.Tensor] = None,
47
81
  attention_mask: Optional[torch.Tensor] = None,
48
- rotary_emb: Optional[torch.Tensor] = None,
82
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
49
83
  ) -> torch.Tensor:
50
84
  encoder_hidden_states_img = None
51
85
  if attn.add_k_proj is not None:
52
- encoder_hidden_states_img = encoder_hidden_states[:, :257]
53
- encoder_hidden_states = encoder_hidden_states[:, 257:]
54
- if encoder_hidden_states is None:
55
- encoder_hidden_states = hidden_states
86
+ # 512 is the context length of the text encoder, hardcoded for now
87
+ image_context_length = encoder_hidden_states.shape[1] - 512
88
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
89
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
56
90
 
57
- query = attn.to_q(hidden_states)
58
- key = attn.to_k(encoder_hidden_states)
59
- value = attn.to_v(encoder_hidden_states)
91
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
60
92
 
61
- if attn.norm_q is not None:
62
- query = attn.norm_q(query)
63
- if attn.norm_k is not None:
64
- key = attn.norm_k(key)
93
+ query = attn.norm_q(query)
94
+ key = attn.norm_k(key)
65
95
 
66
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
67
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
68
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
96
+ query = query.unflatten(2, (attn.heads, -1))
97
+ key = key.unflatten(2, (attn.heads, -1))
98
+ value = value.unflatten(2, (attn.heads, -1))
69
99
 
70
100
  if rotary_emb is not None:
71
101
 
72
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
73
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
74
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
75
- return x_out.type_as(hidden_states)
76
-
77
- query = apply_rotary_emb(query, rotary_emb)
78
- key = apply_rotary_emb(key, rotary_emb)
102
+ def apply_rotary_emb(
103
+ hidden_states: torch.Tensor,
104
+ freqs_cos: torch.Tensor,
105
+ freqs_sin: torch.Tensor,
106
+ ):
107
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
108
+ cos = freqs_cos[..., 0::2]
109
+ sin = freqs_sin[..., 1::2]
110
+ out = torch.empty_like(hidden_states)
111
+ out[..., 0::2] = x1 * cos - x2 * sin
112
+ out[..., 1::2] = x1 * sin + x2 * cos
113
+ return out.type_as(hidden_states)
114
+
115
+ query = apply_rotary_emb(query, *rotary_emb)
116
+ key = apply_rotary_emb(key, *rotary_emb)
79
117
 
80
118
  # I2V task
81
119
  hidden_states_img = None
82
120
  if encoder_hidden_states_img is not None:
83
- key_img = attn.add_k_proj(encoder_hidden_states_img)
121
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
84
122
  key_img = attn.norm_added_k(key_img)
85
- value_img = attn.add_v_proj(encoder_hidden_states_img)
86
-
87
- key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
88
- value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
89
123
 
90
- hidden_states_img = F.scaled_dot_product_attention(
91
- query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
124
+ key_img = key_img.unflatten(2, (attn.heads, -1))
125
+ value_img = value_img.unflatten(2, (attn.heads, -1))
126
+
127
+ hidden_states_img = dispatch_attention_fn(
128
+ query,
129
+ key_img,
130
+ value_img,
131
+ attn_mask=None,
132
+ dropout_p=0.0,
133
+ is_causal=False,
134
+ backend=self._attention_backend,
92
135
  )
93
- hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
136
+ hidden_states_img = hidden_states_img.flatten(2, 3)
94
137
  hidden_states_img = hidden_states_img.type_as(query)
95
138
 
96
- hidden_states = F.scaled_dot_product_attention(
97
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
139
+ hidden_states = dispatch_attention_fn(
140
+ query,
141
+ key,
142
+ value,
143
+ attn_mask=attention_mask,
144
+ dropout_p=0.0,
145
+ is_causal=False,
146
+ backend=self._attention_backend,
98
147
  )
99
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
148
+ hidden_states = hidden_states.flatten(2, 3)
100
149
  hidden_states = hidden_states.type_as(query)
101
150
 
102
151
  if hidden_states_img is not None:
@@ -107,15 +156,140 @@ class WanAttnProcessor2_0:
107
156
  return hidden_states
108
157
 
109
158
 
159
+ class WanAttnProcessor2_0:
160
+ def __new__(cls, *args, **kwargs):
161
+ deprecation_message = (
162
+ "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
163
+ "Please use WanAttnProcessor instead. "
164
+ )
165
+ deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
166
+ return WanAttnProcessor(*args, **kwargs)
167
+
168
+
169
+ class WanAttention(torch.nn.Module, AttentionModuleMixin):
170
+ _default_processor_cls = WanAttnProcessor
171
+ _available_processors = [WanAttnProcessor]
172
+
173
+ def __init__(
174
+ self,
175
+ dim: int,
176
+ heads: int = 8,
177
+ dim_head: int = 64,
178
+ eps: float = 1e-5,
179
+ dropout: float = 0.0,
180
+ added_kv_proj_dim: Optional[int] = None,
181
+ cross_attention_dim_head: Optional[int] = None,
182
+ processor=None,
183
+ is_cross_attention=None,
184
+ ):
185
+ super().__init__()
186
+
187
+ self.inner_dim = dim_head * heads
188
+ self.heads = heads
189
+ self.added_kv_proj_dim = added_kv_proj_dim
190
+ self.cross_attention_dim_head = cross_attention_dim_head
191
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
192
+
193
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
194
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
195
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
196
+ self.to_out = torch.nn.ModuleList(
197
+ [
198
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
199
+ torch.nn.Dropout(dropout),
200
+ ]
201
+ )
202
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
203
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
204
+
205
+ self.add_k_proj = self.add_v_proj = None
206
+ if added_kv_proj_dim is not None:
207
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
208
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
209
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
210
+
211
+ self.is_cross_attention = cross_attention_dim_head is not None
212
+
213
+ self.set_processor(processor)
214
+
215
+ def fuse_projections(self):
216
+ if getattr(self, "fused_projections", False):
217
+ return
218
+
219
+ if self.cross_attention_dim_head is None:
220
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
221
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
222
+ out_features, in_features = concatenated_weights.shape
223
+ with torch.device("meta"):
224
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
225
+ self.to_qkv.load_state_dict(
226
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
227
+ )
228
+ else:
229
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
230
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
231
+ out_features, in_features = concatenated_weights.shape
232
+ with torch.device("meta"):
233
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
234
+ self.to_kv.load_state_dict(
235
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
236
+ )
237
+
238
+ if self.added_kv_proj_dim is not None:
239
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
240
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
241
+ out_features, in_features = concatenated_weights.shape
242
+ with torch.device("meta"):
243
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
244
+ self.to_added_kv.load_state_dict(
245
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
246
+ )
247
+
248
+ self.fused_projections = True
249
+
250
+ @torch.no_grad()
251
+ def unfuse_projections(self):
252
+ if not getattr(self, "fused_projections", False):
253
+ return
254
+
255
+ if hasattr(self, "to_qkv"):
256
+ delattr(self, "to_qkv")
257
+ if hasattr(self, "to_kv"):
258
+ delattr(self, "to_kv")
259
+ if hasattr(self, "to_added_kv"):
260
+ delattr(self, "to_added_kv")
261
+
262
+ self.fused_projections = False
263
+
264
+ def forward(
265
+ self,
266
+ hidden_states: torch.Tensor,
267
+ encoder_hidden_states: Optional[torch.Tensor] = None,
268
+ attention_mask: Optional[torch.Tensor] = None,
269
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
270
+ **kwargs,
271
+ ) -> torch.Tensor:
272
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
273
+
274
+
110
275
  class WanImageEmbedding(torch.nn.Module):
111
- def __init__(self, in_features: int, out_features: int):
276
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
112
277
  super().__init__()
113
278
 
114
279
  self.norm1 = FP32LayerNorm(in_features)
115
280
  self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
116
281
  self.norm2 = FP32LayerNorm(out_features)
282
+ if pos_embed_seq_len is not None:
283
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
284
+ else:
285
+ self.pos_embed = None
117
286
 
118
287
  def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
288
+ if self.pos_embed is not None:
289
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
290
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
291
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
292
+
119
293
  hidden_states = self.norm1(encoder_hidden_states_image)
120
294
  hidden_states = self.ff(hidden_states)
121
295
  hidden_states = self.norm2(hidden_states)
@@ -130,6 +304,7 @@ class WanTimeTextImageEmbedding(nn.Module):
130
304
  time_proj_dim: int,
131
305
  text_embed_dim: int,
132
306
  image_embed_dim: Optional[int] = None,
307
+ pos_embed_seq_len: Optional[int] = None,
133
308
  ):
134
309
  super().__init__()
135
310
 
@@ -141,15 +316,18 @@ class WanTimeTextImageEmbedding(nn.Module):
141
316
 
142
317
  self.image_embedder = None
143
318
  if image_embed_dim is not None:
144
- self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
319
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
145
320
 
146
321
  def forward(
147
322
  self,
148
323
  timestep: torch.Tensor,
149
324
  encoder_hidden_states: torch.Tensor,
150
325
  encoder_hidden_states_image: Optional[torch.Tensor] = None,
326
+ timestep_seq_len: Optional[int] = None,
151
327
  ):
152
328
  timestep = self.timesteps_proj(timestep)
329
+ if timestep_seq_len is not None:
330
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
153
331
 
154
332
  time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
155
333
  if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -166,7 +344,11 @@ class WanTimeTextImageEmbedding(nn.Module):
166
344
 
167
345
  class WanRotaryPosEmbed(nn.Module):
168
346
  def __init__(
169
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
347
+ self,
348
+ attention_head_dim: int,
349
+ patch_size: Tuple[int, int, int],
350
+ max_seq_len: int,
351
+ theta: float = 10000.0,
170
352
  ):
171
353
  super().__init__()
172
354
 
@@ -176,37 +358,55 @@ class WanRotaryPosEmbed(nn.Module):
176
358
 
177
359
  h_dim = w_dim = 2 * (attention_head_dim // 6)
178
360
  t_dim = attention_head_dim - h_dim - w_dim
361
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
362
+
363
+ freqs_cos = []
364
+ freqs_sin = []
179
365
 
180
- freqs = []
181
366
  for dim in [t_dim, h_dim, w_dim]:
182
- freq = get_1d_rotary_pos_embed(
183
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
367
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
368
+ dim,
369
+ max_seq_len,
370
+ theta,
371
+ use_real=True,
372
+ repeat_interleave_real=True,
373
+ freqs_dtype=freqs_dtype,
184
374
  )
185
- freqs.append(freq)
186
- self.freqs = torch.cat(freqs, dim=1)
375
+ freqs_cos.append(freq_cos)
376
+ freqs_sin.append(freq_sin)
377
+
378
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
379
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
187
380
 
188
381
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
189
382
  batch_size, num_channels, num_frames, height, width = hidden_states.shape
190
383
  p_t, p_h, p_w = self.patch_size
191
384
  ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
192
385
 
193
- self.freqs = self.freqs.to(hidden_states.device)
194
- freqs = self.freqs.split_with_sizes(
195
- [
196
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
197
- self.attention_head_dim // 6,
198
- self.attention_head_dim // 6,
199
- ],
200
- dim=1,
201
- )
386
+ split_sizes = [
387
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
388
+ self.attention_head_dim // 3,
389
+ self.attention_head_dim // 3,
390
+ ]
391
+
392
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
393
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
202
394
 
203
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
204
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
205
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
206
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
207
- return freqs
395
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
396
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
397
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
208
398
 
399
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
400
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
401
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
209
402
 
403
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
404
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
405
+
406
+ return freqs_cos, freqs_sin
407
+
408
+
409
+ @maybe_allow_in_graph
210
410
  class WanTransformerBlock(nn.Module):
211
411
  def __init__(
212
412
  self,
@@ -222,33 +422,24 @@ class WanTransformerBlock(nn.Module):
222
422
 
223
423
  # 1. Self-attention
224
424
  self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
225
- self.attn1 = Attention(
226
- query_dim=dim,
425
+ self.attn1 = WanAttention(
426
+ dim=dim,
227
427
  heads=num_heads,
228
- kv_heads=num_heads,
229
428
  dim_head=dim // num_heads,
230
- qk_norm=qk_norm,
231
429
  eps=eps,
232
- bias=True,
233
- cross_attention_dim=None,
234
- out_bias=True,
235
- processor=WanAttnProcessor2_0(),
430
+ cross_attention_dim_head=None,
431
+ processor=WanAttnProcessor(),
236
432
  )
237
433
 
238
434
  # 2. Cross-attention
239
- self.attn2 = Attention(
240
- query_dim=dim,
435
+ self.attn2 = WanAttention(
436
+ dim=dim,
241
437
  heads=num_heads,
242
- kv_heads=num_heads,
243
438
  dim_head=dim // num_heads,
244
- qk_norm=qk_norm,
245
439
  eps=eps,
246
- bias=True,
247
- cross_attention_dim=None,
248
- out_bias=True,
249
440
  added_kv_proj_dim=added_kv_proj_dim,
250
- added_proj_bias=True,
251
- processor=WanAttnProcessor2_0(),
441
+ cross_attention_dim_head=dim // num_heads,
442
+ processor=WanAttnProcessor(),
252
443
  )
253
444
  self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
254
445
 
@@ -265,18 +456,32 @@ class WanTransformerBlock(nn.Module):
265
456
  temb: torch.Tensor,
266
457
  rotary_emb: torch.Tensor,
267
458
  ) -> torch.Tensor:
268
- shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
269
- self.scale_shift_table + temb.float()
270
- ).chunk(6, dim=1)
459
+ if temb.ndim == 4:
460
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
461
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
462
+ self.scale_shift_table.unsqueeze(0) + temb.float()
463
+ ).chunk(6, dim=2)
464
+ # batch_size, seq_len, 1, inner_dim
465
+ shift_msa = shift_msa.squeeze(2)
466
+ scale_msa = scale_msa.squeeze(2)
467
+ gate_msa = gate_msa.squeeze(2)
468
+ c_shift_msa = c_shift_msa.squeeze(2)
469
+ c_scale_msa = c_scale_msa.squeeze(2)
470
+ c_gate_msa = c_gate_msa.squeeze(2)
471
+ else:
472
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
473
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
474
+ self.scale_shift_table + temb.float()
475
+ ).chunk(6, dim=1)
271
476
 
272
477
  # 1. Self-attention
273
478
  norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
274
- attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
479
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
275
480
  hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
276
481
 
277
482
  # 2. Cross-attention
278
483
  norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
279
- attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
484
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
280
485
  hidden_states = hidden_states + attn_output
281
486
 
282
487
  # 3. Feed-forward
@@ -289,7 +494,9 @@ class WanTransformerBlock(nn.Module):
289
494
  return hidden_states
290
495
 
291
496
 
292
- class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
497
+ class WanTransformer3DModel(
498
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
499
+ ):
293
500
  r"""
294
501
  A Transformer model for video-like data used in the Wan model.
295
502
 
@@ -331,6 +538,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
331
538
  _no_split_modules = ["WanTransformerBlock"]
332
539
  _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
333
540
  _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
541
+ _repeated_blocks = ["WanTransformerBlock"]
334
542
 
335
543
  @register_to_config
336
544
  def __init__(
@@ -350,6 +558,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
350
558
  image_dim: Optional[int] = None,
351
559
  added_kv_proj_dim: Optional[int] = None,
352
560
  rope_max_seq_len: int = 1024,
561
+ pos_embed_seq_len: Optional[int] = None,
353
562
  ) -> None:
354
563
  super().__init__()
355
564
 
@@ -368,6 +577,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
368
577
  time_proj_dim=inner_dim * 6,
369
578
  text_embed_dim=text_dim,
370
579
  image_embed_dim=image_dim,
580
+ pos_embed_seq_len=pos_embed_seq_len,
371
581
  )
372
582
 
373
583
  # 3. Transformer blocks
@@ -422,10 +632,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
422
632
  hidden_states = self.patch_embedding(hidden_states)
423
633
  hidden_states = hidden_states.flatten(2).transpose(1, 2)
424
634
 
635
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
636
+ if timestep.ndim == 2:
637
+ ts_seq_len = timestep.shape[1]
638
+ timestep = timestep.flatten() # batch_size * seq_len
639
+ else:
640
+ ts_seq_len = None
641
+
425
642
  temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
426
- timestep, encoder_hidden_states, encoder_hidden_states_image
643
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
427
644
  )
428
- timestep_proj = timestep_proj.unflatten(1, (6, -1))
645
+ if ts_seq_len is not None:
646
+ # batch_size, seq_len, 6, inner_dim
647
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
648
+ else:
649
+ # batch_size, 6, inner_dim
650
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
429
651
 
430
652
  if encoder_hidden_states_image is not None:
431
653
  encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -441,7 +663,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
441
663
  hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
442
664
 
443
665
  # 5. Output norm, projection & unpatchify
444
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
666
+ if temb.ndim == 3:
667
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
668
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
669
+ shift = shift.squeeze(2)
670
+ scale = scale.squeeze(2)
671
+ else:
672
+ # batch_size, inner_dim
673
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
445
674
 
446
675
  # Move the shift and scale tensors to the same device as hidden_states.
447
676
  # When using multi-GPU inference via accelerate these will be on the