diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
--- a/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         "HunyuanVideoPatchEmbed",
         "HunyuanVideoTokenRefiner",
     ]
+    _repeated_blocks = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
 
     @register_to_config
     def __init__(
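
The `_repeated_blocks` registry added above feeds the per-block compilation path on `ModelMixin`: each block class named in the list is compiled once and the compiled graph is reused for every instance, which is much cheaper than compiling the whole model. A minimal usage sketch, assuming the `compile_repeated_blocks` helper shipped in this release; the checkpoint id and dtype are illustrative, not taken from this diff:

    # Hedged sketch: compile only the repeated transformer blocks registered
    # in `_repeated_blocks`, rather than the full model graph.
    import torch
    from diffusers import HunyuanVideoTransformer3DModel

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        "hunyuanvideo-community/HunyuanVideo",  # assumed repo id, for illustration
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # Each class listed in `_repeated_blocks` is compiled once and the compiled
    # graph is reused for every instance of that block.
    transformer.compile_repeated_blocks(fullgraph=True)
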
@@ -1068,17 +1074,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         latent_sequence_length = hidden_states.shape[1]
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
-        attention_mask = torch.zeros(
+        attention_mask = torch.ones(
             batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
         )  # [B, N]
-
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
-
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0)  # [1, N]
+        mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
+        attention_mask = attention_mask.masked_fill(mask_indices, False)
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
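
The rewritten mask construction above trades the per-sample Python loop for one broadcast comparison. A self-contained sketch with toy shapes, verifying the two formulations produce the same boolean mask:

    import torch

    batch_size, sequence_length = 2, 6
    effective_sequence_length = torch.tensor([4, 6])  # [B,]

    # Old approach: start from zeros and loop over the batch.
    loop_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
    for i in range(batch_size):
        loop_mask[i, : effective_sequence_length[i]] = True

    # New approach: start from ones and mask out each sample's padded tail.
    indices = torch.arange(sequence_length).unsqueeze(0)              # [1, N]
    mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
    vec_mask = torch.ones(batch_size, sequence_length, dtype=torch.bool)
    vec_mask = vec_mask.masked_fill(mask_indices, False)

    assert torch.equal(loop_mask, vec_mask)
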
@@ -0,0 +1,416 @@
1
+ # Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, List, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+ from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
24
+ from ..cache_utils import CacheMixin
25
+ from ..embeddings import get_1d_rotary_pos_embed
26
+ from ..modeling_outputs import Transformer2DModelOutput
27
+ from ..modeling_utils import ModelMixin
28
+ from ..normalization import AdaLayerNormContinuous
29
+ from .transformer_hunyuan_video import (
30
+ HunyuanVideoConditionEmbedding,
31
+ HunyuanVideoPatchEmbed,
32
+ HunyuanVideoSingleTransformerBlock,
33
+ HunyuanVideoTokenRefiner,
34
+ HunyuanVideoTransformerBlock,
35
+ )
36
+
37
+
38
+ logger = get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
42
+ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
43
+ super().__init__()
44
+
45
+ self.patch_size = patch_size
46
+ self.patch_size_t = patch_size_t
47
+ self.rope_dim = rope_dim
48
+ self.theta = theta
49
+
50
+ def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
51
+ height = height // self.patch_size
52
+ width = width // self.patch_size
53
+ grid = torch.meshgrid(
54
+ frame_indices.to(device=device, dtype=torch.float32),
55
+ torch.arange(0, height, device=device, dtype=torch.float32),
56
+ torch.arange(0, width, device=device, dtype=torch.float32),
57
+ indexing="ij",
58
+ ) # 3 * [W, H, T]
59
+ grid = torch.stack(grid, dim=0) # [3, W, H, T]
60
+
61
+ freqs = []
62
+ for i in range(3):
63
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
64
+ freqs.append(freq)
65
+
66
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
67
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
68
+
69
+ return freqs_cos, freqs_sin
70
+
71
+
72
+ class FramepackClipVisionProjection(nn.Module):
73
+ def __init__(self, in_channels: int, out_channels: int):
74
+ super().__init__()
75
+ self.up = nn.Linear(in_channels, out_channels * 3)
76
+ self.down = nn.Linear(out_channels * 3, out_channels)
77
+
78
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
79
+ hidden_states = self.up(hidden_states)
80
+ hidden_states = F.silu(hidden_states)
81
+ hidden_states = self.down(hidden_states)
82
+ return hidden_states
83
+
84
+
85
+ class HunyuanVideoHistoryPatchEmbed(nn.Module):
86
+ def __init__(self, in_channels: int, inner_dim: int):
87
+ super().__init__()
88
+ self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
89
+ self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
90
+ self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
91
+
92
+ def forward(
93
+ self,
94
+ latents_clean: Optional[torch.Tensor] = None,
95
+ latents_clean_2x: Optional[torch.Tensor] = None,
96
+ latents_clean_4x: Optional[torch.Tensor] = None,
97
+ ):
98
+ if latents_clean is not None:
99
+ latents_clean = self.proj(latents_clean)
100
+ latents_clean = latents_clean.flatten(2).transpose(1, 2)
101
+ if latents_clean_2x is not None:
102
+ latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
103
+ latents_clean_2x = self.proj_2x(latents_clean_2x)
104
+ latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
105
+ if latents_clean_4x is not None:
106
+ latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
107
+ latents_clean_4x = self.proj_4x(latents_clean_4x)
108
+ latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
109
+ return latents_clean, latents_clean_2x, latents_clean_4x
110
+
111
+
112
+ class HunyuanVideoFramepackTransformer3DModel(
113
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
114
+ ):
115
+ _supports_gradient_checkpointing = True
116
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
117
+ _no_split_modules = [
118
+ "HunyuanVideoTransformerBlock",
119
+ "HunyuanVideoSingleTransformerBlock",
120
+ "HunyuanVideoHistoryPatchEmbed",
121
+ "HunyuanVideoTokenRefiner",
122
+ ]
123
+
124
+ @register_to_config
125
+ def __init__(
126
+ self,
127
+ in_channels: int = 16,
128
+ out_channels: int = 16,
129
+ num_attention_heads: int = 24,
130
+ attention_head_dim: int = 128,
131
+ num_layers: int = 20,
132
+ num_single_layers: int = 40,
133
+ num_refiner_layers: int = 2,
134
+ mlp_ratio: float = 4.0,
135
+ patch_size: int = 2,
136
+ patch_size_t: int = 1,
137
+ qk_norm: str = "rms_norm",
138
+ guidance_embeds: bool = True,
139
+ text_embed_dim: int = 4096,
140
+ pooled_projection_dim: int = 768,
141
+ rope_theta: float = 256.0,
142
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
143
+ image_condition_type: Optional[str] = None,
144
+ has_image_proj: int = False,
145
+ image_proj_dim: int = 1152,
146
+ has_clean_x_embedder: int = False,
147
+ ) -> None:
148
+ super().__init__()
149
+
150
+ inner_dim = num_attention_heads * attention_head_dim
151
+ out_channels = out_channels or in_channels
152
+
153
+ # 1. Latent and condition embedders
154
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
155
+
156
+ # Framepack history projection embedder
157
+ self.clean_x_embedder = None
158
+ if has_clean_x_embedder:
159
+ self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
160
+
161
+ self.context_embedder = HunyuanVideoTokenRefiner(
162
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
163
+ )
164
+
165
+ # Framepack image-conditioning embedder
166
+ self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
167
+
168
+ self.time_text_embed = HunyuanVideoConditionEmbedding(
169
+ inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
170
+ )
171
+
172
+ # 2. RoPE
173
+ self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
174
+
175
+ # 3. Dual stream transformer blocks
176
+ self.transformer_blocks = nn.ModuleList(
177
+ [
178
+ HunyuanVideoTransformerBlock(
179
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
180
+ )
181
+ for _ in range(num_layers)
182
+ ]
183
+ )
184
+
185
+ # 4. Single stream transformer blocks
186
+ self.single_transformer_blocks = nn.ModuleList(
187
+ [
188
+ HunyuanVideoSingleTransformerBlock(
189
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
190
+ )
191
+ for _ in range(num_single_layers)
192
+ ]
193
+ )
194
+
195
+ # 5. Output projection
196
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
197
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
198
+
199
+ self.gradient_checkpointing = False
200
+
201
+ def forward(
202
+ self,
203
+ hidden_states: torch.Tensor,
204
+ timestep: torch.LongTensor,
205
+ encoder_hidden_states: torch.Tensor,
206
+ encoder_attention_mask: torch.Tensor,
207
+ pooled_projections: torch.Tensor,
208
+ image_embeds: torch.Tensor,
209
+ indices_latents: torch.Tensor,
210
+ guidance: Optional[torch.Tensor] = None,
211
+ latents_clean: Optional[torch.Tensor] = None,
212
+ indices_latents_clean: Optional[torch.Tensor] = None,
213
+ latents_history_2x: Optional[torch.Tensor] = None,
214
+ indices_latents_history_2x: Optional[torch.Tensor] = None,
215
+ latents_history_4x: Optional[torch.Tensor] = None,
216
+ indices_latents_history_4x: Optional[torch.Tensor] = None,
217
+ attention_kwargs: Optional[Dict[str, Any]] = None,
218
+ return_dict: bool = True,
219
+ ):
220
+ if attention_kwargs is not None:
221
+ attention_kwargs = attention_kwargs.copy()
222
+ lora_scale = attention_kwargs.pop("scale", 1.0)
223
+ else:
224
+ lora_scale = 1.0
225
+
226
+ if USE_PEFT_BACKEND:
227
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
228
+ scale_lora_layers(self, lora_scale)
229
+ else:
230
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
231
+ logger.warning(
232
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
233
+ )
234
+
235
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
236
+ p, p_t = self.config.patch_size, self.config.patch_size_t
237
+ post_patch_num_frames = num_frames // p_t
238
+ post_patch_height = height // p
239
+ post_patch_width = width // p
240
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
241
+
242
+ if indices_latents is None:
243
+ indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
244
+
245
+ hidden_states = self.x_embedder(hidden_states)
246
+ image_rotary_emb = self.rope(
247
+ frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
248
+ )
249
+
250
+ latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
251
+ latents_clean, latents_history_2x, latents_history_4x
252
+ )
253
+
254
+ if latents_clean is not None and indices_latents_clean is not None:
255
+ image_rotary_emb_clean = self.rope(
256
+ frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
257
+ )
258
+ if latents_history_2x is not None and indices_latents_history_2x is not None:
259
+ image_rotary_emb_history_2x = self.rope(
260
+ frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
261
+ )
262
+ if latents_history_4x is not None and indices_latents_history_4x is not None:
263
+ image_rotary_emb_history_4x = self.rope(
264
+ frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
265
+ )
266
+
267
+ hidden_states, image_rotary_emb = self._pack_history_states(
268
+ hidden_states,
269
+ latents_clean,
270
+ latents_history_2x,
271
+ latents_history_4x,
272
+ image_rotary_emb,
273
+ image_rotary_emb_clean,
274
+ image_rotary_emb_history_2x,
275
+ image_rotary_emb_history_4x,
276
+ post_patch_height,
277
+ post_patch_width,
278
+ )
279
+
280
+ temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
281
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
282
+
283
+ encoder_hidden_states_image = self.image_projection(image_embeds)
284
+ attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
285
+
286
+ # must cat before (not after) encoder_hidden_states, due to attn masking
287
+ encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
288
+ encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
289
+
290
+ latent_sequence_length = hidden_states.shape[1]
291
+ condition_sequence_length = encoder_hidden_states.shape[1]
292
+ sequence_length = latent_sequence_length + condition_sequence_length
293
+ attention_mask = torch.zeros(
294
+ batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
295
+ ) # [B, N]
296
+ effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
297
+ effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
298
+
299
+ if batch_size == 1:
300
+ encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
301
+ attention_mask = None
302
+ else:
303
+ for i in range(batch_size):
304
+ attention_mask[i, : effective_sequence_length[i]] = True
305
+ # [B, 1, 1, N], for broadcasting across attention heads
306
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
307
+
308
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
309
+ for block in self.transformer_blocks:
310
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
311
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
312
+ )
313
+
314
+ for block in self.single_transformer_blocks:
315
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
316
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
317
+ )
318
+
319
+ else:
320
+ for block in self.transformer_blocks:
321
+ hidden_states, encoder_hidden_states = block(
322
+ hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
323
+ )
324
+
325
+ for block in self.single_transformer_blocks:
326
+ hidden_states, encoder_hidden_states = block(
327
+ hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
328
+ )
329
+
330
+ hidden_states = hidden_states[:, -original_context_length:]
331
+ hidden_states = self.norm_out(hidden_states, temb)
332
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (hidden_states,)
+        return Transformer2DModelOutput(sample=hidden_states)
+
+    def _pack_history_states(
+        self,
+        hidden_states: torch.Tensor,
+        latents_clean: Optional[torch.Tensor] = None,
+        latents_history_2x: Optional[torch.Tensor] = None,
+        latents_history_4x: Optional[torch.Tensor] = None,
+        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
+        image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        height: int = None,
+        width: int = None,
+    ):
+        image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
+
+        if latents_clean is not None and image_rotary_emb_clean is not None:
+            hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+
+        if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
+            hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
+            image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
+
+        if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
+            hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
+            image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
+
+        return hidden_states, tuple(image_rotary_emb)
+
+    def _pad_rotary_emb(
+        self,
+        image_rotary_emb: Tuple[torch.Tensor],
+        height: int,
+        width: int,
+        kernel_size: Tuple[int, int, int],
+    ):
+        # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
+        freqs_cos, freqs_sin = image_rotary_emb
+        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
+        freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
+        freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
+        freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
+        freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
+        freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
+        return freqs_cos, freqs_sin
+
+
+def _pad_for_3d_conv(x, kernel_size):
+    if isinstance(x, (tuple, list)):
+        return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
+    b, c, t, h, w = x.shape
+    pt, ph, pw = kernel_size
+    pad_t = (pt - (t % pt)) % pt
+    pad_h = (ph - (h % ph)) % ph
+    pad_w = (pw - (w % pw)) % pw
+    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
+
+
+def _center_down_sample_3d(x, kernel_size):
+    if isinstance(x, (tuple, list)):
+        return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
+    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
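The two module-level helpers pad each spatiotemporal dimension up to a multiple of the kernel size (replicating edge values) and then average-pool with stride equal to the kernel, so the history rotary tables are downsampled without dropping boundary positions. A minimal sketch of that round trip, assuming only PyTorch; the shapes here are illustrative, not the model's real dimensions:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 5, 7, 9)  # [B, C, T, H, W], dims not divisible by 2
    pt, ph, pw = 2, 2, 2
    pad_t, pad_h, pad_w = (pt - 5 % pt) % pt, (ph - 7 % ph) % ph, (pw - 9 % pw) % pw
    x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")  # -> [1, 4, 6, 8, 10]
    x = F.avg_pool3d(x, (pt, ph, pw), stride=(pt, ph, pw))          # -> [1, 4, 3, 4, 5]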
@@ -1,4 +1,4 @@
-# Copyright 2024 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """
 
+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
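The rename keeps backward compatibility through a `__new__`-based shim: instantiating the old class emits a deprecation warning and returns an instance of the new class. A generic sketch of the pattern, with illustrative class names rather than the diffusers API:

    import warnings

    class NewProcessor:
        def __init__(self, scale: float = 1.0):
            self.scale = scale

    class OldProcessor:
        def __new__(cls, *args, **kwargs):
            warnings.warn("OldProcessor is deprecated; use NewProcessor.", FutureWarning)
            return NewProcessor(*args, **kwargs)

    proc = OldProcessor(scale=0.5)
    assert isinstance(proc, NewProcessor)  # OldProcessor.__init__ is never run

Because `__new__` returns an object that is not an instance of the old class, Python skips the old `__init__` entirely, so all construction arguments flow to the replacement.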
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         hidden_states = attn.to_out[0](hidden_states)
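Note the layout change in this hunk: the old path transposed q/k/v into the [B, H, S, D] layout that `F.scaled_dot_product_attention` expects, while the new code passes [B, S, H, D] straight to `dispatch_attention_fn` and flattens the last two dims without transposing back, implying the dispatcher returns the same [B, S, H, D] layout. A toy sketch of the equivalent round trip through plain SDPA, with illustrative shapes:

    import torch
    import torch.nn.functional as F

    b, s, h, d = 2, 16, 4, 8
    q = torch.randn(b, s, h, d)  # [B, S, H, D], the layout kept by the new code
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)            # round-trip through SDPA's [B, H, S, D] layout
    out = out.flatten(2, 3)      # [B, S, H * D], matching the old transpose + flatten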
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
         return hidden_states
 
 
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
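`LTXAttention.forward` filters `**kwargs` against the processor's call signature rather than letting an unexpected keyword raise a `TypeError` inside the processor. The same pattern in isolation; the `processor` function here is a stand-in, not the real callable:

    import inspect

    def processor(hidden_states, scale=1.0):
        return hidden_states * scale

    accepted = set(inspect.signature(processor).parameters)
    kwargs = {"scale": 2.0, "unused_flag": True}
    filtered = {k: v for k, v in kwargs.items() if k in accepted}
    print(processor(3.0, **filtered))  # 6.0; unused_flag is dropped with a warning, not an error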
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
         super().__init__()
 
         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@ class LTXVideoTransformerBlock(nn.Module):
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@ class LTXVideoTransformerBlock(nn.Module):
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@
 
 
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
 
@@ -328,6 +408,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
     def __init__(
@@ -481,7 +562,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
 def apply_rotary_emb(x, freqs):
     cos, sin = freqs
-    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
+    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]
     x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
     return out
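The comment fix is accurate: the processor calls `apply_rotary_emb` on the [B, S, C] projections before heads are split, so the interleaved (real, imag) pairs live along the channel axis, not a head axis. A quick norm-preservation check of the rotation, with made-up shapes and angles:

    import torch

    def apply_rotary_emb(x, freqs):
        cos, sin = freqs
        x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]
        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    x = torch.randn(1, 3, 8)                              # [B, S, C]
    theta = torch.rand(1, 3, 1)                           # one angle per position
    cos, sin = torch.cos(theta).expand(1, 3, 8), torch.sin(theta).expand(1, 3, 8)
    y = apply_rotary_emb(x, (cos, sin))
    assert torch.allclose(y.norm(dim=-1), x.norm(dim=-1), atol=1e-5)  # rotations preserve norm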
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.