diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
1
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,19 +18,19 @@ import torch.nn as nn
18
18
 
19
19
  from ...configuration_utils import ConfigMixin, register_to_config
20
20
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
21
- from ...models.attention import FeedForward, JointTransformerBlock
22
- from ...models.attention_processor import (
21
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
22
+ from ...utils.torch_utils import maybe_allow_in_graph
23
+ from ..attention import FeedForward, JointTransformerBlock
24
+ from ..attention_processor import (
23
25
  Attention,
24
26
  AttentionProcessor,
25
27
  FusedJointAttnProcessor2_0,
26
28
  JointAttnProcessor2_0,
27
29
  )
28
- from ...models.modeling_utils import ModelMixin
29
- from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
30
- from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
31
- from ...utils.torch_utils import maybe_allow_in_graph
32
30
  from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
33
31
  from ..modeling_outputs import Transformer2DModelOutput
32
+ from ..modeling_utils import ModelMixin
33
+ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
34
34
 
35
35
 
36
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -0,0 +1,607 @@
1
+ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25
+ from ..attention import FeedForward
26
+ from ..attention_processor import Attention
27
+ from ..cache_utils import CacheMixin
28
+ from ..embeddings import (
29
+ PixArtAlphaTextProjection,
30
+ TimestepEmbedding,
31
+ get_1d_rotary_pos_embed,
32
+ get_1d_sincos_pos_embed_from_grid,
33
+ )
34
+ from ..modeling_outputs import Transformer2DModelOutput
35
+ from ..modeling_utils import ModelMixin, get_parameter_dtype
36
+ from ..normalization import FP32LayerNorm
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
class SkyReelsV2AttnProcessor2_0:
    """Attention processor for SkyReels-V2 built on `F.scaled_dot_product_attention` (PyTorch 2.0+).

    Serves both self-attention (``encoder_hidden_states is None``) and cross-attention.
    When ``attn.add_k_proj`` is present (I2V task), the leading tokens of
    ``encoder_hidden_states`` are treated as image context, attended to separately,
    and the result is summed with the text-attention output.
    """

    def __init__(self):
        # `F.scaled_dot_product_attention` was introduced in PyTorch 2.0; fail fast otherwise.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now;
            # everything before the trailing 512 tokens is image context (I2V task).
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            # Self-attention: keys/values come from the hidden states themselves.
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # (b, seq, heads * head_dim) -> (b, heads, seq, head_dim)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                # RoPE: rotate channel pairs by the complex frequencies; computed in
                # float32 and cast back to the input dtype.
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            # Image context is attended without a mask (the mask applies to self-attention only).
            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        # (b, heads, seq, head_dim) -> (b, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            # Combine image-context and text-context attention results (I2V).
            hidden_states = hidden_states + hidden_states_img

        # Output projection; to_out[1] is typically the dropout module in diffusers' Attention.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
123
+
124
+
125
class SkyReelsV2ImageEmbedding(torch.nn.Module):
    """Projects image-encoder features into the transformer hidden space.

    Pipeline: FP32 LayerNorm -> FeedForward projection -> FP32 LayerNorm, with an
    optional learned positional embedding added first (used by some I2V variants).
    """

    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) if pos_embed_seq_len is not None else None
        )

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            # Fold pairs of batch entries into a doubled sequence before adding the
            # positional table — assumes the batch layout of the upstream Wan model;
            # TODO(review): confirm batch size is even here.
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

        normed = self.norm1(encoder_hidden_states_image)
        projected = self.ff(normed)
        return self.norm2(projected)
148
+
149
+
150
class SkyReelsV2Timesteps(nn.Module):
    """Sinusoidal timestep projection producing ``num_channels``-dimensional embeddings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
        super().__init__()
        self.num_channels = num_channels
        self.output_type = output_type
        self.flip_sin_to_cos = flip_sin_to_cos

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        input_shape = timesteps.shape
        t_emb = get_1d_sincos_pos_embed_from_grid(
            self.num_channels,
            timesteps,
            output_type=self.output_type,
            flip_sin_to_cos=self.flip_sin_to_cos,
        )
        # Multi-dimensional timestep grids (e.g. one timestep per frame) keep their
        # batch structure, with the channel dimension appended last.
        return t_emb.reshape(*input_shape, self.num_channels) if len(input_shape) > 1 else t_emb
169
+
170
+
171
class SkyReelsV2TimeTextImageEmbedding(nn.Module):
    """Joint conditioning embedder: timestep, text, and (optionally) image features.

    Returns the raw time embedding, its 6-way AdaLN projection, the projected text
    embeddings, and the projected image embeddings (or ``None``).
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
        # Image embedder only exists for I2V-style checkpoints.
        self.image_embedder = (
            SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
            if image_embed_dim is not None
            else None
        )

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
    ):
        freq_emb = self.timesteps_proj(timestep)

        # Cast to the embedder's parameter dtype, except int8 (e.g. quantized weights).
        target_dtype = get_parameter_dtype(self.time_embedder)
        if freq_emb.dtype != target_dtype and target_dtype != torch.int8:
            freq_emb = freq_emb.to(target_dtype)

        temb = self.time_embedder(freq_emb).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
212
+
213
+
214
class SkyReelsV2RotaryPosEmbed(nn.Module):
    """3D rotary positional embedding with separate frequency bands for the
    temporal, height, and width axes of the patch grid."""

    def __init__(
        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Height and width each take 2 * (dim // 6) channels; time gets the remainder.
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        self.freqs = torch.cat(
            [
                get_1d_rotary_pos_embed(
                    axis_dim,
                    max_seq_len,
                    theta,
                    use_real=False,
                    repeat_interleave_real=False,
                    freqs_dtype=torch.float32,
                )
                for axis_dim in (t_dim, h_dim, w_dim)
            ],
            dim=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        _, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        # Complex frequencies hold dim/2 columns per axis; these sizes mirror the
        # t/h/w channel split performed in __init__.
        split_sizes = [
            self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
            self.attention_head_dim // 6,
            self.attention_head_dim // 6,
        ]
        freqs_t, freqs_h, freqs_w = self.freqs.to(hidden_states.device).split_with_sizes(split_sizes, dim=1)

        # Broadcast each axis' frequencies over the full (f, h, w) patch grid.
        freqs_t = freqs_t[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        return torch.cat([freqs_t, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
255
+
256
+
257
class SkyReelsV2TransformerBlock(nn.Module):
    """A single SkyReels-V2 transformer block.

    Structure: AdaLN-modulated self-attention, cross-attention to the text/image
    conditioning, and a gated feed-forward network. Normalization and modulation run
    in float32 for numerical stability and are cast back to the input dtype.

    Args:
        dim: Hidden size of the block.
        ffn_dim: Inner dimension of the feed-forward network.
        num_heads: Number of attention heads.
        qk_norm: Query/key normalization variant passed to `Attention`.
        cross_attn_norm: If True, apply an affine FP32 LayerNorm before cross-attention.
        eps: Epsilon for all normalization layers.
        added_kv_proj_dim: Extra K/V projection dim for image conditioning (I2V); None for T2V.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=SkyReelsV2AttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=SkyReelsV2AttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # Learned per-block shift/scale/gate table, combined with the projected timestep.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        if temb.dim() == 3:
            # Standard case: one modulation vector per batch element.
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)
        elif temb.dim() == 4:
            # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
            e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
        else:
            # Previously an unsupported shape surfaced later as an UnboundLocalError;
            # fail here with an explicit message instead.
            raise ValueError(f"Expected `temb` to have 3 or 4 dimensions, but got {temb.dim()}.")

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
        return hidden_states
342
+
343
+
344
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `16`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `8192`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        inject_sample_info (`bool`, defaults to `False`):
            Whether to inject sample information (an fps embedding) into the model.
        image_dim (`int`, *optional*):
            The dimension of the image embeddings.
        added_kv_proj_dim (`int`, *optional*):
            The dimension of the added key/value projection.
        rope_max_seq_len (`int`, defaults to `1024`):
            The maximum sequence length for the rotary embeddings.
        pos_embed_seq_len (`int`, *optional*):
            The sequence length for the positional embeddings.
        num_frame_per_block (`int`, defaults to `1`):
            Number of latent frames per causal attention block; values greater than 1 enable
            the block-causal attention mask used for autoregressive (Diffusion Forcing) generation.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["SkyReelsV2TransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8192,
        num_layers: int = 32,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        inject_sample_info: bool = False,
        num_frame_per_block: int = 1,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                SkyReelsV2TransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        if inject_sample_info:
            # fps conditioning: a 2-entry embedding table projected to the 6 AdaLN slots.
            self.fps_embedding = nn.Embedding(2, inner_dim)
            self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        enable_diffusion_forcing: bool = False,
        fps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Record whether the caller explicitly passed a LoRA `scale` *before* popping
        # it, so the no-PEFT warning below can actually fire. (Previously the check
        # ran after the pop and was unreachable.)
        scale_was_passed = attention_kwargs is not None and attention_kwargs.get("scale") is not None
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_was_passed:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        # Patchify: (b, c, f, h, w) -> (b, f*h*w patches, inner_dim)
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Block-causal attention mask: patches attend within their own frame block
        # and to all earlier blocks (autoregressive / Diffusion Forcing generation).
        causal_mask = None
        if self.config.num_frame_per_block > 1:
            block_num = post_patch_num_frames // self.config.num_frame_per_block
            range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
                self.config.num_frame_per_block
            )
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
            causal_mask = causal_mask.repeat(
                1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
            )
            causal_mask = causal_mask.reshape(
                post_patch_num_frames * post_patch_height * post_patch_width,
                post_patch_num_frames * post_patch_height * post_patch_width,
            )
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )

        # Expose the 6 AdaLN modulation slots: (..., 6 * inner_dim) -> (..., 6, inner_dim)
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        if self.config.inject_sample_info:
            fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)

            fps_emb = self.fps_embedding(fps)
            if enable_diffusion_forcing:
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
                    timestep.shape[1], 1, 1
                )
            else:
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))

        if enable_diffusion_forcing:
            # Broadcast the per-frame conditioning over the spatial patch grid.
            b, f = timestep.shape
            temb = temb.view(b, f, 1, 1, -1)
            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1)  # (b, f, 1, 1, 6, inner_dim)
            temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
            timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
                1, 3
            )  # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
            timestep_proj = timestep_proj.transpose(1, 2).contiguous()  # (b, 6, f * pp_h * pp_w, inner_dim)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    causal_mask,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    causal_mask,
                )

        if temb.dim() == 2:
            # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
            # For models:
            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        elif temb.dim() == 3:
            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
            # For models:
            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
            shift, scale = shift.squeeze(1), scale.squeeze(1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)

        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: fold the per-patch channels back into (b, c, f, h, w).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _set_ar_attention(self, causal_block_size: int):
        # Persist the causal block size in the model config so it survives save/load.
        self.register_to_config(num_frame_per_block=causal_block_size)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.