diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         latent_sequence_length = hidden_states.shape[1]
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
-        attention_mask = torch.zeros(
+        attention_mask = torch.ones(
             batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
         )  # [B, N]
-
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
-
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0)  # [1, N]
+        mask_indices = indices >= effective_sequence_length.unsqueeze(1)  # [B, N]
+        attention_mask = attention_mask.masked_fill(mask_indices, False)
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, N]

         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
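The rewrite above vectorizes the attention-mask construction: instead of starting from `torch.zeros` and loop-filling `True` per batch entry, it starts from `torch.ones` and masks out the padding positions in a single `masked_fill`. A standalone sketch (sizes made up, not part of diffusers) checking that the two constructions agree:

```python
import torch

batch_size, latent_len, cond_len = 2, 5, 4
effective_cond_len = torch.tensor([3, 1])        # valid condition tokens per sample
sequence_length = latent_len + cond_len
effective_len = latent_len + effective_cond_len  # [B,]

# Old construction: zeros, then a Python loop over the batch.
loop_mask = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
for i in range(batch_size):
    loop_mask[i, : effective_len[i]] = True

# New construction: ones, then mask all padding positions at once.
vec_mask = torch.ones(batch_size, sequence_length, dtype=torch.bool)
indices = torch.arange(sequence_length).unsqueeze(0)  # [1, N]
vec_mask = vec_mask.masked_fill(indices >= effective_len.unsqueeze(1), False)

assert torch.equal(loop_mask, vec_mask)
```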
diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -0,0 +1,416 @@
+# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
+from ..cache_utils import CacheMixin
+from ..embeddings import get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+from .transformer_hunyuan_video import (
+    HunyuanVideoConditionEmbedding,
+    HunyuanVideoPatchEmbed,
+    HunyuanVideoSingleTransformerBlock,
+    HunyuanVideoTokenRefiner,
+    HunyuanVideoTransformerBlock,
+)
+
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.rope_dim = rope_dim
+        self.theta = theta
+
+    def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
+        height = height // self.patch_size
+        width = width // self.patch_size
+        grid = torch.meshgrid(
+            frame_indices.to(device=device, dtype=torch.float32),
+            torch.arange(0, height, device=device, dtype=torch.float32),
+            torch.arange(0, width, device=device, dtype=torch.float32),
+            indexing="ij",
+        )  # 3 * [W, H, T]
+        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+        freqs = []
+        for i in range(3):
+            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+            freqs.append(freq)
+
+        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+
+        return freqs_cos, freqs_sin
+
+
+class FramepackClipVisionProjection(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super().__init__()
+        self.up = nn.Linear(in_channels, out_channels * 3)
+        self.down = nn.Linear(out_channels * 3, out_channels)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.up(hidden_states)
+        hidden_states = F.silu(hidden_states)
+        hidden_states = self.down(hidden_states)
+        return hidden_states
+
+
+class HunyuanVideoHistoryPatchEmbed(nn.Module):
+    def __init__(self, in_channels: int, inner_dim: int):
+        super().__init__()
+        self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+        self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
+        self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
+
+    def forward(
+        self,
+        latents_clean: Optional[torch.Tensor] = None,
+        latents_clean_2x: Optional[torch.Tensor] = None,
+        latents_clean_4x: Optional[torch.Tensor] = None,
+    ):
+        if latents_clean is not None:
+            latents_clean = self.proj(latents_clean)
+            latents_clean = latents_clean.flatten(2).transpose(1, 2)
+        if latents_clean_2x is not None:
+            latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
+            latents_clean_2x = self.proj_2x(latents_clean_2x)
+            latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
+        if latents_clean_4x is not None:
+            latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
+            latents_clean_4x = self.proj_4x(latents_clean_4x)
+            latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
+        return latents_clean, latents_clean_2x, latents_clean_4x
+
+
+class HunyuanVideoFramepackTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+    _no_split_modules = [
+        "HunyuanVideoTransformerBlock",
+        "HunyuanVideoSingleTransformerBlock",
+        "HunyuanVideoHistoryPatchEmbed",
+        "HunyuanVideoTokenRefiner",
+    ]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_attention_heads: int = 24,
+        attention_head_dim: int = 128,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        num_refiner_layers: int = 2,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
+        qk_norm: str = "rms_norm",
+        guidance_embeds: bool = True,
+        text_embed_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        rope_theta: float = 256.0,
+        rope_axes_dim: Tuple[int] = (16, 56, 56),
+        image_condition_type: Optional[str] = None,
+        has_image_proj: int = False,
+        image_proj_dim: int = 1152,
+        has_clean_x_embedder: int = False,
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Latent and condition embedders
+        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+
+        # Framepack history projection embedder
+        self.clean_x_embedder = None
+        if has_clean_x_embedder:
+            self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
+
+        self.context_embedder = HunyuanVideoTokenRefiner(
+            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+        )
+
+        # Framepack image-conditioning embedder
+        self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
+
+        self.time_text_embed = HunyuanVideoConditionEmbedding(
+            inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+        )
+
+        # 2. RoPE
+        self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+        # 3. Dual stream transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Single stream transformer blocks
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoSingleTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        # 5. Output projection
+        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        image_embeds: torch.Tensor,
+        indices_latents: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
+        latents_clean: Optional[torch.Tensor] = None,
+        indices_latents_clean: Optional[torch.Tensor] = None,
+        latents_history_2x: Optional[torch.Tensor] = None,
+        indices_latents_history_2x: Optional[torch.Tensor] = None,
+        latents_history_4x: Optional[torch.Tensor] = None,
+        indices_latents_history_4x: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p, p_t = self.config.patch_size, self.config.patch_size_t
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
+
+        if indices_latents is None:
+            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+        hidden_states = self.x_embedder(hidden_states)
+        image_rotary_emb = self.rope(
+            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+        )
+
+        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+            latents_clean, latents_history_2x, latents_history_4x
+        )
+
+        if latents_clean is not None and indices_latents_clean is not None:
+            image_rotary_emb_clean = self.rope(
+                frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_2x is not None and indices_latents_history_2x is not None:
+            image_rotary_emb_history_2x = self.rope(
+                frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_4x is not None and indices_latents_history_4x is not None:
+            image_rotary_emb_history_4x = self.rope(
+                frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
+            )
+
+        hidden_states, image_rotary_emb = self._pack_history_states(
+            hidden_states,
+            latents_clean,
+            latents_history_2x,
+            latents_history_4x,
+            image_rotary_emb,
+            image_rotary_emb_clean,
+            image_rotary_emb_history_2x,
+            image_rotary_emb_history_4x,
+            post_patch_height,
+            post_patch_width,
+        )
+
+        temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+        encoder_hidden_states_image = self.image_projection(image_embeds)
+        attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
+
+        # must cat before (not after) encoder_hidden_states, due to attn masking
+        encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+        encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
+
+        latent_sequence_length = hidden_states.shape[1]
+        condition_sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = latent_sequence_length + condition_sequence_length
+        attention_mask = torch.zeros(
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
+        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
+        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
+
+        if batch_size == 1:
+            encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
+            attention_mask = None
+        else:
+            for i in range(batch_size):
+                attention_mask[i, : effective_sequence_length[i]] = True
+            # [B, 1, 1, N], for broadcasting across attention heads
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+        hidden_states = hidden_states[:, -original_context_length:]
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (hidden_states,)
+        return Transformer2DModelOutput(sample=hidden_states)
+
+    def _pack_history_states(
+        self,
+        hidden_states: torch.Tensor,
+        latents_clean: Optional[torch.Tensor] = None,
+        latents_history_2x: Optional[torch.Tensor] = None,
+        latents_history_4x: Optional[torch.Tensor] = None,
+        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
+        image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        height: int = None,
+        width: int = None,
+    ):
+        image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
+
+        if latents_clean is not None and image_rotary_emb_clean is not None:
+            hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+
+        if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
+            hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
+            image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)

+        if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
+            hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
+            image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
+            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
+            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
+
+        return hidden_states, tuple(image_rotary_emb)
+
+    def _pad_rotary_emb(
+        self,
+        image_rotary_emb: Tuple[torch.Tensor],
+        height: int,
+        width: int,
+        kernel_size: Tuple[int, int, int],
+    ):
+        # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
+        freqs_cos, freqs_sin = image_rotary_emb
+        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
+        freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
+        freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
+        freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
+        freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
+        freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
+        return freqs_cos, freqs_sin
+
+
+def _pad_for_3d_conv(x, kernel_size):
+    if isinstance(x, (tuple, list)):
+        return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
+    b, c, t, h, w = x.shape
+    pt, ph, pw = kernel_size
+    pad_t = (pt - (t % pt)) % pt
+    pad_h = (ph - (h % ph)) % ph
+    pad_w = (pw - (w % pw)) % pw
+    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
+
+
+def _center_down_sample_3d(x, kernel_size):
+    if isinstance(x, (tuple, list)):
+        return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
+    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
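In the new file, `_pad_for_3d_conv` replicate-pads the time/height/width dims up to multiples of the kernel size, so the strided history convolutions in `HunyuanVideoHistoryPatchEmbed` and the average pooling in `_pad_rotary_emb` divide their inputs evenly. A quick standalone check of the padding arithmetic (tensor shapes here are illustrative):

```python
import torch
import torch.nn.functional as F

def _pad_for_3d_conv(x, kernel_size):
    # Replicate-pad (t, h, w) up to the next multiple of the kernel size.
    b, c, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - (t % pt)) % pt
    pad_h = (ph - (h % ph)) % ph
    pad_w = (pw - (w % pw)) % pw
    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")

x = torch.randn(1, 16, 3, 30, 52)        # latent history: [B, C, T, H, W]
padded = _pad_for_3d_conv(x, (2, 4, 4))  # T 3 -> 4, H 30 -> 32, W 52 -> 52
print(padded.shape)                      # torch.Size([1, 16, 4, 32, 52])
out = F.avg_pool3d(padded, (2, 4, 4), stride=(2, 4, 4))
print(out.shape)                         # torch.Size([1, 16, 2, 8, 13])
```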
diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Genmo team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -481,7 +481,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

         def apply_rotary_emb(x, freqs):
             cos, sin = freqs
-            x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
+            x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
             out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
             return out
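The corrected comment reflects that `x` here is [B, S, C], so pairing adjacent channels and unbinding yields two [B, S, C // 2] halves. A standalone sketch (shapes made up) of this interleaved rotation, using the identity rotation as a sanity check:

```python
import torch

# Adjacent channel pairs act as (real, imag) components of a complex rotation.
b, s, c = 1, 3, 8
x = torch.randn(b, s, c)
cos, sin = torch.ones(s, c), torch.zeros(s, c)  # identity rotation

x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)            # each [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)  # [B, S, C]
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

assert torch.equal(out, x)  # cos=1, sin=0 leaves the input unchanged
```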
diffusers/models/transformers/transformer_lumina2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/transformers/transformer_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Genmo team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
diffusers/models/transformers/transformer_omnigen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -283,7 +283,7 @@ class OmniGenBlock(nn.Module):

 class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
     """
-    The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
+    The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).

     Parameters:
         in_channels (`int`, defaults to `4`):
diffusers/models/transformers/transformer_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,19 +18,19 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
-from ...models.attention import FeedForward, JointTransformerBlock
-from ...models.attention_processor import (
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward, JointTransformerBlock
+from ..attention_processor import (
     Attention,
     AttentionProcessor,
     FusedJointAttnProcessor2_0,
     JointAttnProcessor2_0,
 )
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
diffusers/models/transformers/transformer_wan.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -49,8 +49,10 @@ class WanAttnProcessor2_0:
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
-            encoder_hidden_states_img = encoder_hidden_states[:, :257]
-            encoder_hidden_states = encoder_hidden_states[:, 257:]
+            # 512 is the context length of the text encoder, hardcoded for now
+            image_context_length = encoder_hidden_states.shape[1] - 512
+            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
         if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
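The split point is now derived from the text encoder's fixed 512-token context rather than assuming exactly 257 image tokens, so inputs that pack a different number of image tokens still split correctly. A standalone sketch (dimensions made up) of the arithmetic:

```python
import torch

# Text tokens are padded to a fixed context of 512, so whatever precedes
# them in the concatenated sequence must be image tokens.
text_context_length = 512
image_embeds = torch.randn(1, 257, 16)  # could just as well be 514 tokens
text_embeds = torch.randn(1, text_context_length, 16)
encoder_hidden_states = torch.cat([image_embeds, text_embeds], dim=1)

image_context_length = encoder_hidden_states.shape[1] - text_context_length
img = encoder_hidden_states[:, :image_context_length]
txt = encoder_hidden_states[:, image_context_length:]
print(img.shape[1], txt.shape[1])  # 257 512
```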
@@ -70,7 +72,8 @@ class WanAttnProcessor2_0:
         if rotary_emb is not None:

             def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
+                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
                 x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                 return x_out.type_as(hidden_states)
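The dtype switch exists because Apple's mps backend does not support float64; `view_as_complex` then pairs adjacent channels into complex numbers for the rotation. A standalone sketch of the same steps (shapes and `freqs` values are made up):

```python
import torch

b, h, s, d = 1, 4, 2, 8
hidden_states = torch.randn(b, h, s, d)

# float64 is unsupported on mps, so fall back to float32 there.
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
x = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))  # [B, H, S, D/2] complex
freqs = torch.polar(
    torch.ones(s, d // 2, dtype=dtype), torch.randn(s, d // 2, dtype=dtype)
)  # unit-magnitude complex rotations, broadcast over batch and heads
out = torch.view_as_real(x * freqs).flatten(3, 4).type_as(hidden_states)
print(out.shape)  # torch.Size([1, 4, 2, 8])
```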
@@ -108,14 +111,23 @@ class WanAttnProcessor2_0:


 class WanImageEmbedding(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int):
+    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()

         self.norm1 = FP32LayerNorm(in_features)
         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
         self.norm2 = FP32LayerNorm(out_features)
+        if pos_embed_seq_len is not None:
+            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+        else:
+            self.pos_embed = None

     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is not None:
+            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
         hidden_states = self.norm1(encoder_hidden_states_image)
         hidden_states = self.ff(hidden_states)
         hidden_states = self.norm2(hidden_states)
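When `pos_embed_seq_len` is set, the `view(-1, 2 * seq_len, embed_dim)` folds pairs of batch entries into one doubled sequence before adding the learned positional embedding; this appears aimed at Wan variants that condition on two stacked image-embedding sequences per sample. A standalone sketch of the reshape (the 257/1280 sizes are illustrative, not prescribed by the diff):

```python
import torch
import torch.nn as nn

seq_len, embed_dim = 257, 1280
pos_embed = nn.Parameter(torch.zeros(1, 2 * seq_len, embed_dim))

image_embeds = torch.randn(4, seq_len, embed_dim)              # [2 * B, L, C], two per sample
image_embeds = image_embeds.view(-1, 2 * seq_len, embed_dim)   # [B, 2L, C]
image_embeds = image_embeds + pos_embed                        # learned positions over the pair
print(image_embeds.shape)  # torch.Size([2, 514, 1280])
```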
@@ -130,6 +142,7 @@ class WanTimeTextImageEmbedding(nn.Module):
         time_proj_dim: int,
         text_embed_dim: int,
         image_embed_dim: Optional[int] = None,
+        pos_embed_seq_len: Optional[int] = None,
     ):
         super().__init__()

@@ -141,7 +154,7 @@ class WanTimeTextImageEmbedding(nn.Module):

         self.image_embedder = None
         if image_embed_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

     def forward(
         self,
@@ -178,9 +191,10 @@ class WanRotaryPosEmbed(nn.Module):
         t_dim = attention_head_dim - h_dim - w_dim

         freqs = []
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for dim in [t_dim, h_dim, w_dim]:
             freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
             )
             freqs.append(freq)
         self.freqs = torch.cat(freqs, dim=1)
@@ -190,8 +204,8 @@ class WanRotaryPosEmbed(nn.Module):
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

-        self.freqs = self.freqs.to(hidden_states.device)
-        freqs = self.freqs.split_with_sizes(
+        freqs = self.freqs.to(hidden_states.device)
+        freqs = freqs.split_with_sizes(
             [
                 self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                 self.attention_head_dim // 6,
@@ -350,6 +364,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
     ) -> None:
         super().__init__()

@@ -368,6 +383,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
             image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
         )

         # 3. Transformer blocks