diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551):
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,978 @@
1
+ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import math
17
+ import re
18
+ from copy import deepcopy
19
+ from typing import Any, Callable, Dict, List, Optional, Union
20
+
21
+ import ftfy
22
+ import torch
23
+ from transformers import AutoTokenizer, UMT5EncoderModel
24
+
25
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from ...loaders import SkyReelsV2LoraLoaderMixin
27
+ from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
28
+ from ...schedulers import UniPCMultistepScheduler
29
+ from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
30
+ from ...utils.torch_utils import randn_tensor
31
+ from ...video_processor import VideoProcessor
32
+ from ..pipeline_utils import DiffusionPipeline
33
+ from .pipeline_output import SkyReelsV2PipelineOutput
34
+
35
+
36
# Optional torch_xla support: the XLA model helpers are imported only when the backend is
# reported as available, and the outcome is recorded in a module-level flag.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger obtained from the package's logging helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# NOTE(review): `ftfy` is also imported unconditionally in the import section at the top of
# this module, so this guarded import does not by itself make the dependency optional —
# confirm which behaviour is intended.
if is_ftfy_available():
    import ftfy
47
+
48
+
49
+ EXAMPLE_DOC_STRING = """\
50
+ Examples:
51
+ ```py
52
+ >>> import torch
53
+ >>> from diffusers import (
54
+ ... SkyReelsV2DiffusionForcingPipeline,
55
+ ... UniPCMultistepScheduler,
56
+ ... AutoencoderKLWan,
57
+ ... )
58
+ >>> from diffusers.utils import export_to_video
59
+
60
+ >>> # Load the pipeline
61
+ >>> # Available models:
62
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
63
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
64
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
65
+ >>> vae = AutoencoderKLWan.from_pretrained(
66
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
67
+ ... subfolder="vae",
68
+ ... torch_dtype=torch.float32,
69
+ ... )
70
+ >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
71
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
72
+ ... vae=vae,
73
+ ... torch_dtype=torch.bfloat16,
74
+ ... )
75
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
76
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
77
+ >>> pipe = pipe.to("cuda")
78
+
79
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
80
+
81
+ >>> output = pipe(
82
+ ... prompt=prompt,
83
+ ... num_inference_steps=30,
84
+ ... height=544,
85
+ ... width=960,
86
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
87
+ ... num_frames=97,
88
+ ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
89
+ ... causal_block_size=5, # Number of frames processed together in a causal block
90
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
91
+ ... addnoise_condition=20, # Improves consistency in long video generation
92
+ ... ).frames[0]
93
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
94
+ ```
95
+ """
96
+
97
+
98
def basic_clean(text):
    """Repair `text` with `ftfy.fix_text`, HTML-unescape it twice, and strip both ends."""
    repaired = ftfy.fix_text(text)
    # Two unescape passes, so doubly-escaped entities (e.g. "&amp;lt;") are fully decoded.
    once_unescaped = html.unescape(repaired)
    fully_unescaped = html.unescape(once_unescaped)
    return fully_unescaped.strip()
102
+
103
+
104
def whitespace_clean(text):
    """Collapse every run of whitespace in `text` into one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
108
+
109
+
110
def prompt_clean(text):
    """Normalize a prompt: apply `basic_clean` first, then `whitespace_clean` on its result."""
    repaired = basic_clean(text)
    return whitespace_clean(repaired)
113
+
114
+
115
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
116
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Pull latents out of an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the `latents` attribute, in that order of
        preference.

    Raises:
        AttributeError: If none of the lookups above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    # Fallback for any other `sample_mode`, or for outputs without a distribution.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
127
+
128
+
129
+ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
130
+ """
131
+ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing.
132
+
133
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
134
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
135
+
136
+ Args:
137
+ tokenizer ([`AutoTokenizer`]):
138
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
139
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
140
+ text_encoder ([`UMT5EncoderModel`]):
141
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
142
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
143
+ transformer ([`SkyReelsV2Transformer3DModel`]):
144
+ Conditional Transformer to denoise the encoded image latents.
145
+ scheduler ([`UniPCMultistepScheduler`]):
146
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
147
+ vae ([`AutoencoderKLWan`]):
148
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
149
+ """
150
+
151
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
152
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
153
+
154
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    transformer: SkyReelsV2Transformer3DModel,
    vae: AutoencoderKLWan,
    scheduler: UniPCMultistepScheduler,
):
    """
    Register the pipeline components and derive the VAE scale factors.

    The temporal factor is `2 ** sum(vae.temperal_downsample)` and the spatial factor is
    `2 ** len(vae.temperal_downsample)`; when no VAE is registered they fall back to 4 and 8.
    """
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "transformer": transformer,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    if getattr(self, "vae", None):
        # `temperal_downsample` is the attribute name as spelled on the VAE object.
        downsample_flags = self.vae.temperal_downsample
        temporal_factor = 2 ** sum(downsample_flags)
        spatial_factor = 2 ** len(downsample_flags)
    else:
        temporal_factor, spatial_factor = 4, 8

    self.vae_scale_factor_temporal = temporal_factor
    self.vae_scale_factor_spatial = spatial_factor
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
175
+
176
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
177
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Tokenize and encode `prompt` with the pipeline's text encoder.

    Each prompt is cleaned with `prompt_clean`, tokenized to exactly `max_sequence_length`
    tokens, and encoded. Positions beyond a prompt's attended length are replaced by zeros, and
    the result is duplicated `num_videos_per_prompt` times per prompt.

    Args:
        prompt: A single prompt or a list of prompts.
        num_videos_per_prompt: Number of copies of each prompt embedding to produce.
        max_sequence_length: Padding/truncation length passed to the tokenizer.
        device: Target device; defaults to `self._execution_device`.
        dtype: Target dtype; defaults to `self.text_encoder.dtype`.

    Returns:
        Tensor reshaped to `(batch_size * num_videos_per_prompt, seq_len, -1)`.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    cleaned_prompts = [prompt_clean(entry) for entry in prompt]
    batch_size = len(cleaned_prompts)

    tokenized = self.tokenizer(
        cleaned_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
    # Number of attended (mask > 0) tokens for each prompt.
    valid_lengths = attention_mask.gt(0).sum(dim=1).long()

    encoded = self.text_encoder(input_ids.to(device), attention_mask.to(device)).last_hidden_state
    encoded = encoded.to(dtype=dtype, device=device)

    # Keep only the attended prefix of every sequence and pad the remainder with zeros.
    zero_padded = []
    for sequence, length in zip(encoded, valid_lengths):
        attended = sequence[:length]
        filler = attended.new_zeros(max_sequence_length - attended.size(0), attended.size(1))
        zero_padded.append(torch.cat([attended, filler]))
    prompt_embeds = torch.stack(zero_padded, dim=0)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds
217
+
218
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
219
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt (and, for classifier-free guidance, the negative prompt) into text
    encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts to encode. May be `None` when `prompt_embeds` is given.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts not to guide generation. An empty string is used when omitted.
            Only encoded when `do_classifier_free_guidance` is true and
            `negative_prompt_embeds` is not given.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings are needed.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. When given, `prompt` is not encoded.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. When given, `negative_prompt` is not
            encoded.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Sequence length forwarded to `_get_t5_prompt_embeds`.
        device (`torch.device`, *optional*):
            Device for the embeddings; defaults to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            Dtype forwarded to `_get_t5_prompt_embeds`.

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds)`; the second item is returned as
        passed in (possibly `None`) when no negative encoding was performed.

    Raises:
        TypeError: If `prompt` and `negative_prompt` end up with different types.
        ValueError: If the negative prompt batch size differs from the prompt batch size.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = prompt_embeds.shape[0] if prompt is None else len(prompt)

    # Arguments shared by the positive and negative encoding calls.
    shared_kwargs = {
        "num_videos_per_prompt": num_videos_per_prompt,
        "max_sequence_length": max_sequence_length,
        "device": device,
        "dtype": dtype,
    }

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(prompt=prompt, **shared_kwargs)

    needs_negative_encoding = do_classifier_free_guidance and negative_prompt_embeds is None
    if not needs_negative_encoding:
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        # Broadcast a single negative prompt over the whole batch.
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(prompt=negative_prompt, **shared_kwargs)

    return prompt_embeds, negative_prompt_embeds
299
+
300
def check_inputs(
    self,
    prompt,
    negative_prompt,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    overlap_history=None,
    num_frames=None,
    base_num_frames=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        negative_prompt: `str`, `list` or `None`; mutually exclusive with
            `negative_prompt_embeds`.
        height: Output height; must be divisible by 16.
        width: Output width; must be divisible by 16.
        prompt_embeds: Pre-computed prompt embeddings, or `None`.
        negative_prompt_embeds: Pre-computed negative prompt embeddings, or `None`.
        callback_on_step_end_tensor_inputs: Names that must all appear in
            `self._callback_tensor_inputs`.
        overlap_history: Required when `num_frames` exceeds `base_num_frames`.
        num_frames: Requested number of frames, or `None` to skip the long-video check.
        base_num_frames: Base number of frames, or `None` to skip the long-video check.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    # Only the first failing condition of this chain is reported.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Both frame counts default to None; comparing None values would raise a TypeError, so the
    # long-video check only runs when both are actually provided.
    frame_counts_given = num_frames is not None and base_num_frames is not None
    if frame_counts_given and num_frames > base_num_frames and overlap_history is None:
        raise ValueError(
            "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
            "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
        )
349
+
350
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 480,
    width: int = 832,
    num_frames: int = 97,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    base_latent_num_frames: Optional[int] = None,
    video_latents: Optional[torch.Tensor] = None,
    causal_block_size: Optional[int] = None,
    overlap_history_latent_frames: Optional[int] = None,
    long_video_iter: Optional[int] = None,
) -> torch.Tensor:
    """
    Create the initial noise latents for one generation window.

    Three cases are handled when `latents` is not supplied:
      * `video_latents` given (long video, iterations after the first): the last
        `overlap_history_latent_frames` frames are taken as prefix, trimmed to a multiple of
        `causal_block_size`, and the window length is derived from the frames still missing.
      * only `base_latent_num_frames` given (long video, first iteration): the window has
        exactly that many latent frames.
      * neither given (short video): the window covers all latent frames of `num_frames`.

    Returns:
        `(latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames)`.

        NOTE(review): when `latents` is supplied, only the tensor moved to `device`/`dtype` is
        returned, not the 4-tuple — confirm that callers handle both shapes.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if latents is not None:
        return latents.to(device=device, dtype=dtype)

    total_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial

    prefix_video_latents = None
    prefix_video_latents_frames = 0

    if video_latents is not None:  # long video generation at the iterations other than the first one
        prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]

        overhang = prefix_video_latents.shape[2] % causal_block_size
        if overhang != 0:
            logger.warning(
                f"The length of prefix video latents is truncated by {overhang} frames for the causal block size alignment. "
                f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
                f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
            )
            prefix_video_latents = prefix_video_latents[:, :, :-overhang]
        prefix_video_latents_frames = prefix_video_latents.shape[2]

        # Latent frames already produced by the previous windows (overlap counted once).
        stride = base_latent_num_frames - overlap_history_latent_frames
        finished_frame_num = long_video_iter * stride + overlap_history_latent_frames
        remaining_frame_num = total_latent_frames - finished_frame_num
        num_latent_frames = min(remaining_frame_num + overlap_history_latent_frames, base_latent_num_frames)
    elif base_latent_num_frames is not None:  # long video generation at the first iteration
        num_latent_frames = base_latent_num_frames
    else:  # short video generation
        num_latent_frames = total_latent_frames

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    noise_shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
    latents = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)

    return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
417
+
418
def generate_timestep_matrix(
    self,
    num_latent_frames: int,
    step_template: torch.Tensor,
    base_num_latent_frames: int,
    ar_step: int = 5,
    num_pre_ready: int = 0,
    causal_block_size: int = 1,
    shrink_interval_with_mask: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    """
    Build the diffusion-forcing denoising schedule across temporal frames.

    Frames are grouped into blocks of `causal_block_size`. Each schedule row holds, per block,
    an index into the (extended) timestep template. Block 0 advances by one index per row; each
    later block trails its predecessor by `ar_step` indices until the predecessor has reached
    the last real timestep, after which it advances by one per row.

    **Synchronous mode** (`ar_step=0`): all blocks carry the same index in every row.

    **Asynchronous mode** (`ar_step>0`): blocks are staggered, so earlier blocks are further
    along the template than later ones.

    Args:
        num_latent_frames (int): Total number of latent frames to schedule.
        step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0]).
        base_num_latent_frames (int): Number of latent frames per processing window.
        ar_step (int, optional): Index lag between consecutive blocks. Defaults to 5.
        num_pre_ready (int, optional): Number of leading frames treated as already finished.
            Defaults to 0.
        causal_block_size (int, optional): Number of frames per block. Defaults to 1.
        shrink_interval_with_mask (bool, optional): Start the processing interval at the last
            block updated in the first row instead of at the full window. Defaults to False.

    Returns:
        tuple containing:
            - step_matrix (torch.Tensor): Timestep value per frame and row, shape
              [num_rows, num_blocks * causal_block_size].
            - step_index (torch.Tensor): Template index per frame and row, same shape.
            - step_update_mask (torch.Tensor): Boolean mask of frames to update, same shape.
            - valid_interval (list[tuple]): One (start, end) frame interval per row.

    Raises:
        ValueError: If there are more blocks than fit in one window and `ar_step` is smaller
            than `len(step_template) / base_num_blocks`.
    """
    # Per-row schedule data, stacked into tensors at the end.
    step_matrix, step_index = [], []
    update_mask, valid_interval = [], []

    # One more than the number of real timesteps; used as the "finished" index.
    num_iterations = len(step_template) + 1

    # Convert frame counts to block counts, e.g. 25 frames / block size 5 = 5 blocks.
    num_blocks = num_latent_frames // causal_block_size
    base_num_blocks = base_num_latent_frames // causal_block_size

    # With more blocks than one window holds, the lag must reach a minimum value.
    if base_num_blocks < num_blocks:
        min_ar_step = len(step_template) / base_num_blocks
        if ar_step < min_ar_step:
            raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")

    # Extend the template so that schedule indices can address it directly:
    #   index 0              -> 999 (placeholder for "not started")
    #   index num_iterations -> 0   (finished)
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )

    # Progress of every block in the previous row: 0 = not started, num_iterations = finished.
    pre_row = torch.zeros(num_blocks, dtype=torch.long)

    # Blocks covered by pre-ready frames start out finished.
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // causal_block_size] = num_iterations

    # Emit rows until every block has reached at least the last real timestep.
    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_blocks, dtype=torch.long)

        for i in range(num_blocks):
            # The first block, or a block whose predecessor already reached the last real
            # timestep in the previous row, simply advances by one step.
            if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                new_row[i] = pre_row[i] + 1
            else:
                # Otherwise trail the predecessor's new position by `ar_step`.
                new_row[i] = new_row[i - 1] - ar_step

        # Clamp values to the valid index range [0, num_iterations].
        new_row = new_row.clamp(0, num_iterations)

        # A block is updated in this row when its index changed and it is not yet finished.
        update_mask.append((new_row != pre_row) & (new_row != num_iterations))

        step_index.append(new_row)  # Index into step_template
        step_matrix.append(step_template[new_row])  # Actual timestep values
        pre_row = new_row

    # End (exclusive, in blocks) of the interval processed in a row.
    terminal_flag = base_num_blocks

    if shrink_interval_with_mask:
        # Only the first row's mask is inspected here. It is kept in its own variable so that
        # `update_mask` stays the list of per-row masks needed by the loop and stack below.
        idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
        first_update_mask = update_mask[0]
        update_mask_idx = idx_sequence[first_update_mask]
        last_update_idx = update_mask_idx[-1].item()
        terminal_flag = last_update_idx + 1

    for curr_mask in update_mask:
        # Grow the interval by one block when the block right after it is updated in this row.
        if terminal_flag < num_blocks and curr_mask[terminal_flag]:
            terminal_flag += 1
        # Interval [start, end) spanning at most `base_num_blocks` blocks.
        valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    # Replicate every block's schedule to all frames inside that block.
    if causal_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        # Scale intervals from block units to frame units.
        valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval
573
+
574
@property
def guidance_scale(self):
    """Return the value currently stored in `self._guidance_scale`."""
    return self._guidance_scale
577
+
578
@property
def do_classifier_free_guidance(self):
    """Return whether `self._guidance_scale` is strictly greater than 1.0."""
    return self._guidance_scale > 1.0
581
+
582
@property
def num_timesteps(self):
    """Return the value currently stored in `self._num_timesteps`."""
    return self._num_timesteps
585
+
586
@property
def current_timestep(self):
    """Return the value currently stored in `self._current_timestep`."""
    return self._current_timestep
589
+
590
@property
def interrupt(self):
    """Return the value currently stored in `self._interrupt`."""
    return self._interrupt
593
+
594
@property
def attention_kwargs(self):
    """Return the value currently stored in `self._attention_kwargs`."""
    return self._attention_kwargs
597
+
598
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    height: int = 544,
    width: int = 960,
    num_frames: int = 97,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    overlap_history: Optional[int] = None,
    addnoise_condition: float = 0,
    base_num_frames: int = 97,
    ar_step: int = 0,
    causal_block_size: Optional[int] = None,
    fps: int = 24,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        height (`int`, defaults to `544`):
            The height of the generated video.
        width (`int`, defaults to `960`):
            The width of the generated video.
        num_frames (`int`, defaults to `97`):
            The number of frames in the generated video. If `num_frames - 1` is not divisible by
            `self.vae_scale_factor_temporal`, the value is rounded down to the closest valid frame count and a
            warning is logged.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `6.0`):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
            usually at the expense of lower quality. (**6.0 for T2V**, **5.0 for I2V**)
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. In long video mode the supplied
            latents are only used for the first iteration.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video, forwarded to `self.video_processor.postprocess_video`.
            Passing `"latent"` skips VAE decoding and returns the denoised latents instead.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, *optional*, defaults to `512`):
            The maximum sequence length of the prompt.
        overlap_history (`int`, *optional*, defaults to `None`):
            Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
            short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
        addnoise_condition (`float`, *optional*, defaults to `0`):
            This is used to help smooth the long video generation by adding some noise to the clean condition. Too
            large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
            ones, but it is recommended to not exceed 50.
        base_num_frames (`int`, *optional*, defaults to `97`):
            97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
        ar_step (`int`, *optional*, defaults to `0`):
            Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
            inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
            to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
            sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
            inference may improve the instruction following and visual consistent performance.
        causal_block_size (`int`, *optional*, defaults to `None`):
            The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
            0). When `None`, `self.transformer.config.num_frame_per_block` is used.
        fps (`int`, *optional*, defaults to `24`):
            Frame rate of the generated video

    Examples:

    Returns:
        [`~SkyReelsV2PipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
            whose only element is the generated video (or the denoised latents when `output_type="latent"`).
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
        overlap_history,
        num_frames,
        base_num_frames,
    )

    if addnoise_condition > 60:
        logger.warning(
            f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
        )

    if num_frames % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    if causal_block_size is None:
        causal_block_size = self.transformer.config.num_frame_per_block
    else:
        self.transformer._set_ar_attention(causal_block_size)

    # One entry per prompt embedding row: 0 when `fps` is 16, 1 for any other frame rate.
    fps_embeds = [0 if fps == 16 else 1] * prompt_embeds.shape[0]

    # Determine if we're doing long video generation
    is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
    # Initialize accumulated_latents to store all latents in one tensor
    accumulated_latents = None
    if is_long_video:
        # Long video generation setup. `is_long_video` already guarantees that both `overlap_history` and
        # `base_num_frames` are set, so no further `None` handling is needed here.
        overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        base_latent_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
        n_iter = (
            1
            + (num_latent_frames - base_latent_num_frames - 1)
            // (base_latent_num_frames - overlap_history_latent_frames)
            + 1
        )
    else:
        # Short video generation setup
        n_iter = 1
        base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

    # Does not change between iterations, so it is read once.
    num_channels_latents = self.transformer.config.in_channels

    # Loop through iterations (multiple iterations only for long videos)
    for iter_idx in range(n_iter):
        if is_long_video:
            logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")

        # 5. Prepare latent variables
        latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
            self.prepare_latents(
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                height,
                width,
                num_frames,
                torch.float32,
                device,
                generator,
                latents if iter_idx == 0 else None,
                video_latents=accumulated_latents,  # Pass latents directly instead of decoded video
                base_latent_num_frames=base_latent_num_frames if is_long_video else None,
                causal_block_size=causal_block_size,
                overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
                long_video_iter=iter_idx if is_long_video else None,
            )
        )

        # Seed the leading frames with the conditioning latents returned by `prepare_latents`.
        if prefix_video_latents_frames > 0:
            latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)

        # 6. Prepare sample schedulers and timestep matrix
        # Every latent frame gets its own scheduler copy because frames are stepped independently below.
        sample_schedulers = []
        for _ in range(current_num_latent_frames):
            sample_scheduler = deepcopy(self.scheduler)
            sample_scheduler.set_timesteps(num_inference_steps, device=device)
            sample_schedulers.append(sample_scheduler)

        # Different matrix generation for short vs long video
        step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
            current_num_latent_frames,
            timesteps,
            current_num_latent_frames if is_long_video else base_latent_num_frames,
            ar_step,
            prefix_video_latents_frames,
            causal_block_size,
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(step_matrix)

        with self.progress_bar(total=len(step_matrix)) as progress_bar:
            for i, t in enumerate(step_matrix):
                if self.interrupt:
                    continue

                self._current_timestep = t
                valid_interval_start, valid_interval_end = valid_interval[i]
                # Only the frames inside the valid interval are fed to the transformer for this step.
                latent_model_input = (
                    latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
                )
                timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()

                if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
                    noise_factor = 0.001 * addnoise_condition
                    # `latent_model_input` and `timestep` are already cut down to the valid interval, so the
                    # conditioning frames [valid_interval_start, prefix_video_latents_frames) sit at the
                    # window-relative positions [0, prefix_end). Using the absolute indices here would only be
                    # correct when the interval starts at frame 0.
                    prefix_end = prefix_video_latents_frames - valid_interval_start
                    prefix_latents = latent_model_input[:, :, :prefix_end, :, :]
                    latent_model_input[:, :, :prefix_end, :, :] = (
                        prefix_latents * (1.0 - noise_factor) + torch.randn_like(prefix_latents) * noise_factor
                    )
                    timestep[:, :prefix_end] = addnoise_condition

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    enable_diffusion_forcing=True,
                    fps=fps_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        enable_diffusion_forcing=True,
                        fps=fps_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # Step each frame flagged in the update mask with its own scheduler; `noise_pred` is indexed
                # relative to the start of the valid interval, `latents` and `t` by absolute frame index.
                update_mask_i = step_update_mask[i]
                for idx in range(valid_interval_start, valid_interval_end):
                    if update_mask_i[idx].item():
                        latents[:, :, idx, :, :] = sample_schedulers[idx].step(
                            noise_pred[:, :, idx - valid_interval_start, :, :],
                            t[idx],
                            latents[:, :, idx, :, :],
                            return_dict=False,
                        )[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # Looked up by name, so the locals of this method must keep their names.
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # update the progress bar
                if i == len(step_matrix) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # Handle latent accumulation for long videos or use the current latents for short videos
        if is_long_video:
            if accumulated_latents is None:
                accumulated_latents = latents
            else:
                # Keep overlap frames for conditioning but don't include them in final output
                accumulated_latents = torch.cat(
                    [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
                )

    if is_long_video:
        latents = accumulated_latents

    self._current_timestep = None

    # Final decoding step - convert latents to pixels
    if output_type != "latent":
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # NOTE: despite its name this holds the reciprocal of the configured std, hence the division below.
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return SkyReelsV2PipelineOutput(frames=video)