diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1059 @@
1
+ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import math
17
+ import re
18
+ from copy import deepcopy
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import ftfy
22
+ import PIL
23
+ import torch
24
+ from transformers import AutoTokenizer, UMT5EncoderModel
25
+
26
+ from diffusers.image_processor import PipelineImageInput
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.video_processor import VideoProcessor
29
+
30
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
31
+ from ...loaders import SkyReelsV2LoraLoaderMixin
32
+ from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
33
+ from ...schedulers import UniPCMultistepScheduler
34
+ from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
35
+ from ..pipeline_utils import DiffusionPipeline
36
+ from .pipeline_output import SkyReelsV2PipelineOutput
37
+
38
+
39
# Import the torch_xla model helpers only when `is_torch_xla_available()` reports them,
# and record the outcome in the module-level flag `XLA_AVAILABLE`.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger named after this module.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# NOTE(review): `ftfy` is also imported unconditionally in the import block at the top of
# this file, so this guard does not make the dependency optional as written - confirm
# which of the two is intended.
if is_ftfy_available():
    import ftfy
50
+
51
+
52
+ EXAMPLE_DOC_STRING = """\
53
+ Examples:
54
+ ```py
55
+ >>> import torch
56
+ >>> from diffusers import (
57
+ ... SkyReelsV2DiffusionForcingImageToVideoPipeline,
58
+ ... UniPCMultistepScheduler,
59
+ ... AutoencoderKLWan,
60
+ ... )
61
+ >>> from diffusers.utils import export_to_video
62
+ >>> from PIL import Image
63
+
64
+ >>> # Load the pipeline
65
+ >>> # Available models:
66
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
67
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
68
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
69
+ >>> vae = AutoencoderKLWan.from_pretrained(
70
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
71
+ ... subfolder="vae",
72
+ ... torch_dtype=torch.float32,
73
+ ... )
74
+ >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
75
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
76
+ ... vae=vae,
77
+ ... torch_dtype=torch.bfloat16,
78
+ ... )
79
+ >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
80
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
81
+ >>> pipe = pipe.to("cuda")
82
+
83
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
84
+ >>> image = Image.open("path/to/image.png")
85
+
86
+ >>> output = pipe(
87
+ ... image=image,
88
+ ... prompt=prompt,
89
+ ... num_inference_steps=50,
90
+ ... height=544,
91
+ ... width=960,
92
+ ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
93
+ ... num_frames=97,
94
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
95
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
96
+ ... addnoise_condition=20, # Improves consistency in long video generation
97
+ ... ).frames[0]
98
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
99
+ ```
100
+ """
101
+
102
+
103
def basic_clean(text):
    """Repair text with `ftfy`, unescape HTML entities twice, and strip surrounding whitespace."""
    repaired = ftfy.fix_text(text)
    # Unescaping is applied twice so that doubly-escaped entities (e.g. "&amp;lt;") are resolved.
    once = html.unescape(repaired)
    twice = html.unescape(once)
    return twice.strip()
107
+
108
+
109
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
113
+
114
+
115
def prompt_clean(text):
    """Apply `basic_clean` followed by `whitespace_clean` to a prompt string."""
    return whitespace_clean(basic_clean(text))
118
+
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
121
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute or a `latents` attribute.
        generator: Passed to `latent_dist.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` returns `latent_dist.mode()`.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents` attribute.

    Raises:
        AttributeError: If no branch applies (no usable `latent_dist` for the requested mode and no `latents`).
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
132
+
133
+
134
+ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
135
+ """
136
+ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing.
137
+
138
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
139
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
140
+
141
+ Args:
142
+ tokenizer ([`AutoTokenizer`]):
143
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
144
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
145
+ text_encoder ([`UMT5EncoderModel`]):
146
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
147
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
148
+ transformer ([`SkyReelsV2Transformer3DModel`]):
149
+ Conditional Transformer to denoise the encoded image latents.
150
+ scheduler ([`UniPCMultistepScheduler`]):
151
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
152
+ vae ([`AutoencoderKLWan`]):
153
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
154
+ """
155
+
156
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
157
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
158
+
159
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    transformer: SkyReelsV2Transformer3DModel,
    vae: AutoencoderKLWan,
    scheduler: UniPCMultistepScheduler,
):
    """
    Register the pipeline components and derive the VAE scale factors.

    The temporal factor is `2 ** sum(vae.temperal_downsample)` and the spatial factor is
    `2 ** len(vae.temperal_downsample)`; when no VAE is registered they fall back to 4 and 8.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    if getattr(self, "vae", None):
        # `temperal_downsample` is the attribute name as referenced on the VAE object.
        downsample_flags = self.vae.temperal_downsample
        self.vae_scale_factor_temporal = 2 ** sum(downsample_flags)
        self.vae_scale_factor_spatial = 2 ** len(downsample_flags)
    else:
        self.vae_scale_factor_temporal = 4
        self.vae_scale_factor_spatial = 8

    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
180
+
181
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
182
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Tokenize and encode `prompt` with the text encoder.

    Each prompt is cleaned, tokenized to `max_sequence_length`, encoded, and the hidden states beyond the
    attended tokens are replaced with zeros. The result is repeated `num_videos_per_prompt` times per prompt.

    Returns:
        Tensor reshaped to `(len(prompt) * num_videos_per_prompt, seq_len, -1)`.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompts = [prompt] if isinstance(prompt, str) else prompt
    prompts = [prompt_clean(item) for item in prompts]
    batch_size = len(prompts)

    tokenized = self.tokenizer(
        prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
    # Number of attended (non-padding) tokens per prompt.
    valid_lengths = attention_mask.gt(0).sum(dim=1).long()

    hidden_states = self.text_encoder(input_ids.to(device), attention_mask.to(device)).last_hidden_state
    hidden_states = hidden_states.to(dtype=dtype, device=device)

    # Keep only the attended tokens, then zero-pad each sample back to `max_sequence_length`.
    padded_samples = []
    for sample, length in zip(hidden_states, valid_lengths):
        kept = sample[:length]
        padding = kept.new_zeros(max_sequence_length - kept.size(0), kept.size(1))
        padded_samples.append(torch.cat([kept, padding]))
    prompt_embeds = torch.stack(padded_samples, dim=0)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    _, seq_len, _ = prompt_embeds.shape
    repeated = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    return repeated.view(batch_size * num_videos_per_prompt, seq_len, -1)
222
+
223
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
224
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt (and, when guidance is enabled, the negative prompt) into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts to be encoded. May be `None` when `prompt_embeds` is given.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts not to guide the generation. An empty string is used when it is not given. Only
            used when `do_classifier_free_guidance` is `True` and `negative_prompt_embeds` is `None`.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative prompt embeddings should be produced.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. If not provided, they are generated from `prompt`.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, they are generated from `negative_prompt`.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Sequence length forwarded to `_get_t5_prompt_embeds`.
        device (`torch.device`, *optional*):
            Device forwarded to `_get_t5_prompt_embeds`; falls back to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            Dtype forwarded to `_get_t5_prompt_embeds`.

    Returns:
        Tuple `(prompt_embeds, negative_prompt_embeds)`; the second item is returned unchanged (possibly
        `None`) when guidance is disabled.

    Raises:
        TypeError: If `prompt` and `negative_prompt` end up with different types.
        ValueError: If the batch size of `negative_prompt` differs from that of `prompt`.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to compute when guidance is off or negative embeddings were supplied.
    if not do_classifier_free_guidance or negative_prompt_embeds is not None:
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

    return prompt_embeds, negative_prompt_embeds
304
+
305
def check_inputs(
    self,
    prompt,
    negative_prompt,
    image,
    height,
    width,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    image_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    overlap_history=None,
    num_frames=None,
    base_num_frames=None,
):
    """
    Validate the arguments of a pipeline call and raise on the first inconsistency found.

    Args:
        prompt: `str`, `list` or `None`; mutually exclusive with `prompt_embeds`.
        negative_prompt: `str`, `list` or `None`; mutually exclusive with `negative_prompt_embeds`.
        image: `torch.Tensor`, `PIL.Image.Image` or `None`; mutually exclusive with `image_embeds`.
        height: Must be divisible by 16.
        width: Must be divisible by 16.
        prompt_embeds: Pre-computed prompt embeddings or `None`.
        negative_prompt_embeds: Pre-computed negative prompt embeddings or `None`.
        image_embeds: Pre-computed image embeddings or `None`.
        callback_on_step_end_tensor_inputs: Names that must all be listed in `self._callback_tensor_inputs`.
        overlap_history: Required when `num_frames` exceeds `base_num_frames`.
        num_frames: Requested number of frames, or `None` to skip the long-video check.
        base_num_frames: Reference frame count for the long-video check, or `None` to skip it.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if image is not None and image_embeds is not None:
        raise ValueError(
            f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if image is None and image_embeds is None:
        raise ValueError(
            "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
        )
    if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
        raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif negative_prompt is not None and (
        not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
    ):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    # Both frame counts default to None; comparing None values would raise a TypeError, so the
    # long-video check only runs when both are actually provided.
    if (
        num_frames is not None
        and base_num_frames is not None
        and num_frames > base_num_frames
        and overlap_history is None
    ):
        raise ValueError(
            "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
            "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
        )
367
+
368
def prepare_latents(
    self,
    image: Optional[PipelineImageInput],
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 480,
    width: int = 832,
    num_frames: int = 97,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    last_image: Optional[torch.Tensor] = None,
    video_latents: Optional[torch.Tensor] = None,
    base_latent_num_frames: Optional[int] = None,
    causal_block_size: Optional[int] = None,
    overlap_history_latent_frames: Optional[int] = None,
    long_video_iter: Optional[int] = None,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor], int]:
    """
    Create (or move) the noise latents and build the conditioning latents for the current chunk.

    Args:
        image: Conditioning image tensor; a frame axis is inserted at dim 2 before VAE encoding.
        batch_size: Effective batch size used for the latent shape.
        num_channels_latents: Channel count of the latent shape.
        height: Pixel height, divided by `self.vae_scale_factor_spatial` for the latent shape.
        width: Pixel width, divided by `self.vae_scale_factor_spatial` for the latent shape.
        num_frames: Total number of frames requested.
        dtype: Dtype of the returned latents and condition.
        device: Device of the returned latents.
        generator: Generator or list of generators (list length must equal `batch_size`).
        latents: Optional pre-made latents; only moved to `device`/`dtype` when given.
        last_image: Optional second image, concatenated with `image` along dim 0 before encoding.
        video_latents: Latents of the video generated so far (long-video iterations after the first).
        base_latent_num_frames: Latent frames per chunk for long-video generation.
        causal_block_size: Block size the prefix length is aligned to.
        overlap_history_latent_frames: Number of trailing latent frames of `video_latents` used as prefix.
        long_video_iter: Index of the current long-video iteration.

    Returns:
        `(latents, num_latent_frames, condition, prefix_video_latents_frames)`. `condition` is `None` and
        the prefix length is 0 when neither `video_latents` nor `image` is provided.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial

    # Initialised up front so the return statement is valid even when no conditioning input is given.
    condition = None
    prefix_video_latents_frames = 0

    if video_latents is not None:  # long video generation at the iterations other than the first one
        condition = video_latents[:, :, -overlap_history_latent_frames:]

        if condition.shape[2] % causal_block_size != 0:
            truncate_len_latents = condition.shape[2] % causal_block_size
            logger.warning(
                f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
                f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
                f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
            )
            condition = condition[:, :, :-truncate_len_latents]
        prefix_video_latents_frames = condition.shape[2]

        # Latent frames already produced by the previous iterations (chunks overlap by the history length).
        finished_frame_num = (
            long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
            + overlap_history_latent_frames
        )
        left_frame_num = num_latent_frames - finished_frame_num
        num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
    elif base_latent_num_frames is not None:  # long video generation at the first iteration
        num_latent_frames = base_latent_num_frames
    # else: short video generation keeps the frame count computed at the top of the function

    shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device=device, dtype=dtype)

    if image is not None:
        image = image.unsqueeze(2)
        if last_image is not None:
            last_image = last_image.unsqueeze(2)
            video_condition = torch.cat([image, last_image], dim=0)
        else:
            video_condition = image

        video_condition = video_condition.to(device=device, dtype=self.vae.dtype)

        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # Stored as the reciprocal so normalisation below is a multiplication.
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )

        if isinstance(generator, list):
            latent_condition = [
                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
            ]
            latent_condition = torch.cat(latent_condition)
        else:
            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
            latent_condition = latent_condition.repeat_interleave(batch_size, dim=0)

        latent_condition = latent_condition.to(dtype)
        # An image condition overrides any prefix taken from `video_latents` above.
        condition = (latent_condition - latents_mean) * latents_std
        prefix_video_latents_frames = condition.shape[2]

    return latents, num_latent_frames, condition, prefix_video_latents_frames
462
+
463
+ # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
464
def generate_timestep_matrix(
    self,
    num_latent_frames: int,
    step_template: torch.Tensor,
    base_num_latent_frames: int,
    ar_step: int = 5,
    num_pre_ready: int = 0,
    causal_block_size: int = 1,
    shrink_interval_with_mask: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    """
    This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
    across temporal frames. It supports both synchronous and asynchronous generation modes:

    **Synchronous Mode** (ar_step=0, causal_block_size=1):
    - All frames are denoised simultaneously at each timestep
    - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]

    **Asynchronous Mode** (ar_step>0, causal_block_size>1):
    - Frames are grouped into causal blocks and processed block/chunk-wise
    - Each block is denoised in a staggered pattern creating a "denoising wave"
    - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps

    Args:
        num_latent_frames (int): Total number of latent frames to generate
        step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
        base_num_latent_frames (int): Maximum frames the model can process in one forward pass
        ar_step (int, optional): Autoregressive step size for temporal lag.
            0 = synchronous, >0 = asynchronous. Defaults to 5.
        num_pre_ready (int, optional):
            Number of frames already denoised (e.g., from prefix in a video2video task).
            Defaults to 0.
        causal_block_size (int, optional): Number of frames processed as a causal block.
            Defaults to 1.
        shrink_interval_with_mask (bool, optional): When True, the initial window end is taken from the last
            block updated in the first iteration instead of `base_num_latent_frames // causal_block_size`.
            Defaults to False.

    Returns:
        tuple containing:
            - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
              [num_iterations, num_latent_frames]
            - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
              num_latent_frames]
            - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
              [num_iterations, num_latent_frames]
            - valid_interval (list[tuple]): List of (start, end) intervals for each iteration

    Raises:
        ValueError: If ar_step is too small for the given configuration
    """
    # NOTE: this method carries a `# Copied from` marker in the file; the source implementation
    # needs the same `shrink_interval_with_mask` fix to keep the copies in sync.

    # Initialize lists to store the scheduling matrices and metadata
    step_matrix, step_index = [], []  # Will store timestep values and indices for each iteration
    update_mask, valid_interval = [], []  # Will store update masks and processing intervals

    # Calculate total number of denoising iterations (add 1 for initial noise state)
    num_iterations = len(step_template) + 1

    # Convert frame counts to block counts for causal processing
    # Each block contains causal_block_size frames that are processed together
    # E.g.: 25 frames ÷ 5 = 5 blocks total
    num_blocks = num_latent_frames // causal_block_size
    base_num_blocks = base_num_latent_frames // causal_block_size

    # Validate ar_step is sufficient for the given configuration
    # In asynchronous mode, we need enough timesteps to create the staggered pattern
    if base_num_blocks < num_blocks:
        min_ar_step = len(step_template) / base_num_blocks
        if ar_step < min_ar_step:
            raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")

    # Extend step_template with boundary values for easier indexing
    # 999: dummy value for counter starting from 1
    # 0: final timestep (completely denoised)
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )

    # Initialize the previous row state (tracks denoising progress for each block)
    # 0 means not started, num_iterations means fully denoised
    pre_row = torch.zeros(num_blocks, dtype=torch.long)

    # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // causal_block_size] = num_iterations

    # Main loop: Generate denoising schedule until all frames are fully denoised
    while not torch.all(pre_row >= (num_iterations - 1)):
        # Create new row representing the next denoising step
        new_row = torch.zeros(num_blocks, dtype=torch.long)

        # Apply diffusion forcing logic for each block
        for i in range(num_blocks):
            if i == 0 or pre_row[i - 1] >= (
                num_iterations - 1
            ):  # the first frame or the last frame is completely denoised
                new_row[i] = pre_row[i] + 1
            else:
                # Asynchronous mode: lag behind previous block by ar_step timesteps
                # This creates the "diffusion forcing" staggered pattern
                new_row[i] = new_row[i - 1] - ar_step

        # Clamp values to valid range [0, num_iterations]
        new_row = new_row.clamp(0, num_iterations)

        # Create update mask: True for blocks that need denoising update at this iteration
        # Exclude blocks that haven't started (new_row == pre_row) or are finished (new_row == num_iterations)
        update_mask.append((new_row != pre_row) & (new_row != num_iterations))

        # Store the iteration state
        step_index.append(new_row)  # Index into step_template
        step_matrix.append(step_template[new_row])  # Actual timestep values
        pre_row = new_row  # Update for next iteration

    # For videos longer than model capacity, we process in sliding windows
    terminal_flag = base_num_blocks

    # Optional optimization: shrink interval based on first update mask
    if shrink_interval_with_mask:
        idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
        # Use a separate name: `update_mask` must remain the list of per-iteration masks because it
        # is iterated and stacked below.
        first_update_mask = update_mask[0]
        update_mask_idx = idx_sequence[first_update_mask]
        last_update_idx = update_mask_idx[-1].item()
        terminal_flag = last_update_idx + 1

    # Each interval defines which frames to process in the current forward pass
    for curr_mask in update_mask:
        # Extend terminal flag if current mask has updates beyond current terminal
        if terminal_flag < num_blocks and curr_mask[terminal_flag]:
            terminal_flag += 1
        # Create interval: [start, end) where start ensures we don't exceed model capacity
        valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))

    # Convert lists to tensors for efficient processing
    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    # Each block's schedule is replicated to all frames within that block
    if causal_block_size > 1:
        # Expand each block to causal_block_size frames
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        # Scale intervals from block-level to frame-level
        valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval
619
+
620
@property
def guidance_scale(self):
    """Return the classifier-free guidance weight stored in `self._guidance_scale`.

    `__call__` assigns this attribute from its `guidance_scale` argument.
    """
    scale = self._guidance_scale
    return scale
623
+
624
@property
def do_classifier_free_guidance(self):
    """Tell whether classifier-free guidance is active.

    Guidance is considered enabled only when the stored guidance scale is strictly greater than `1.0`.
    """
    scale = self._guidance_scale
    return scale > 1.0
627
+
628
@property
def num_timesteps(self):
    """Return the step count stored in `self._num_timesteps`.

    `__call__` sets this to the number of rows of the timestep matrix it iterates over.
    """
    count = self._num_timesteps
    return count
631
+
632
@property
def current_timestep(self):
    """Return the value stored in `self._current_timestep`.

    `__call__` resets it to `None` before and after denoising and sets it to the current timestep-matrix row
    on every denoising step.
    """
    timestep_now = self._current_timestep
    return timestep_now
635
+
636
@property
def interrupt(self):
    """Return the interruption flag stored in `self._interrupt`.

    `__call__` clears the flag when it starts; while it is truthy the denoising loop skips its remaining steps.
    """
    flag = self._interrupt
    return flag
639
+
640
@property
def attention_kwargs(self):
    """Return the attention keyword arguments stored in `self._attention_kwargs`.

    `__call__` assigns this attribute from its `attention_kwargs` argument (which may be `None`).
    """
    kwargs = self._attention_kwargs
    return kwargs
643
+
644
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    image: PipelineImageInput,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 544,
    width: int = 960,
    num_frames: int = 97,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    image_embeds: Optional[torch.Tensor] = None,
    last_image: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    overlap_history: Optional[int] = None,
    addnoise_condition: float = 0,
    base_num_frames: int = 97,
    ar_step: int = 0,
    causal_block_size: Optional[int] = None,
    fps: int = 24,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        image (`PipelineImageInput`):
            The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        height (`int`, defaults to `544`):
            The height of the generated video.
        width (`int`, defaults to `960`):
            The width of the generated video.
        num_frames (`int`, defaults to `97`):
            The number of frames in the generated video. If `num_frames - 1` is not divisible by the temporal VAE
            scale factor, the value is rounded down to the closest valid frame count and a warning is logged.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `5.0`):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
            usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`. Only used for the first
            generation window.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `negative_prompt` input argument.
        image_embeds (`torch.Tensor`, *optional*):
            Pre-generated image embeddings. Within this method they are only forwarded to `check_inputs`.
        last_image (`torch.Tensor`, *optional*):
            Optional image to use as the end-frame condition. It is preprocessed the same way as `image`. During
            the last generation window its latents are appended to the sequence, kept fixed (timestep `0`, never
            updated) and removed again before decoding.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. With `"latent"` the latents are returned without VAE
            decoding; any other value is forwarded to `video_processor.postprocess_video`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, *optional*, defaults to `512`):
            The maximum sequence length of the prompt.
        overlap_history (`int`, *optional*, defaults to `None`):
            Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
            short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
        addnoise_condition (`float`, *optional*, defaults to `0`):
            This is used to help smooth the long video generation by adding some noise to the clean condition. Too
            large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
            ones, but it is recommended to not exceed 50. A warning is logged for values above 60.
        base_num_frames (`int`, *optional*, defaults to `97`):
            Base frame count of a single generation window (**97 for 540P**, **121 for 720P**). Long video mode is
            used when `overlap_history` is set and `num_frames` exceeds this value.
        ar_step (`int`, *optional*, defaults to `0`):
            Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
            inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
            to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
            sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
            inference may improve the instruction following and visual consistent performance.
        causal_block_size (`int`, *optional*, defaults to `None`):
            The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
            0). If `None`, `transformer.config.num_frame_per_block` is used.
        fps (`int`, *optional*, defaults to `24`):
            Frame rate of the generated video

    Examples:

    Returns:
        [`~SkyReelsV2PipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
            whose only element is the generated video (or the latents when `output_type="latent"`).
    """

    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        image,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        image_embeds,
        callback_on_step_end_tensor_inputs,
        overlap_history,
        num_frames,
        base_num_frames,
    )

    if addnoise_condition > 60:
        logger.warning(
            f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
        )

    if num_frames % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    if causal_block_size is None:
        causal_block_size = self.transformer.config.num_frame_per_block
    else:
        self.transformer._set_ar_attention(causal_block_size)

    # One fps embedding index per prompt embedding: 0 when fps == 16, otherwise 1
    fps_embeds = [fps] * prompt_embeds.shape[0]
    fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]

    # Determine if we're doing long video generation
    is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
    # Initialize accumulated_latents to store all latents in one tensor
    accumulated_latents = None
    if is_long_video:
        # Long video generation setup (`base_num_frames` is guaranteed to be set on this branch)
        overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        base_latent_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
        # First window plus as many additional windows as are needed to cover the remaining latent frames,
        # each additional window contributing (base - overlap) new frames
        n_iter = (
            1
            + (num_latent_frames - base_latent_num_frames - 1)
            // (base_latent_num_frames - overlap_history_latent_frames)
            + 1
        )
    else:
        # Short video generation setup
        n_iter = 1
        base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

    image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)

    if last_image is not None:
        last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
            device, dtype=torch.float32
        )

    # Loop through iterations (multiple iterations only for long videos)
    for iter_idx in range(n_iter):
        if is_long_video:
            logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")

        num_channels_latents = self.vae.config.z_dim
        # The input image and user-supplied latents are only used for the first window; later windows are
        # conditioned on the latents accumulated so far.
        latents, current_num_latent_frames, condition, prefix_video_latents_frames = self.prepare_latents(
            image if iter_idx == 0 else None,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents if iter_idx == 0 else None,
            last_image,
            video_latents=accumulated_latents,  # Pass latents directly instead of decoded video
            base_latent_num_frames=base_latent_num_frames if is_long_video else None,
            causal_block_size=causal_block_size,
            overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
            long_video_iter=iter_idx if is_long_video else None,
        )

        # Write the clean conditioning latents into the prefix frames. In the first window only the first half
        # (rounded up) of `condition` along dim 0 is used as the prefix.
        if iter_idx == 0:
            latents[:, :, :prefix_video_latents_frames, :, :] = condition[: (condition.shape[0] + 1) // 2].to(
                transformer_dtype
            )
        else:
            latents[:, :, :prefix_video_latents_frames, :, :] = condition.to(transformer_dtype)

        # The second half of `condition` along dim 0 is kept as the end-frame latents
        if iter_idx == 0 and last_image is not None:
            end_video_latents = condition[condition.shape[0] // 2 :].to(transformer_dtype)

        # In the last window the end-frame latents are appended along the frame dimension
        if last_image is not None and iter_idx + 1 == n_iter:
            latents = torch.cat([latents, end_video_latents], dim=2)
            base_latent_num_frames += prefix_video_latents_frames
            current_num_latent_frames += prefix_video_latents_frames

        # 5. Prepare sample schedulers and timestep matrix
        # Every latent frame gets its own scheduler copy because frames are stepped independently.
        sample_schedulers = []
        for _ in range(current_num_latent_frames):
            sample_scheduler = deepcopy(self.scheduler)
            sample_scheduler.set_timesteps(num_inference_steps, device=device)
            sample_schedulers.append(sample_scheduler)
        step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
            current_num_latent_frames,
            timesteps,
            base_latent_num_frames,
            ar_step,
            prefix_video_latents_frames,
            causal_block_size,
        )

        # Appended end-frame latents stay fixed: timestep 0 and never updated
        if last_image is not None and iter_idx + 1 == n_iter:
            step_matrix[:, -prefix_video_latents_frames:] = 0
            step_update_mask[:, -prefix_video_latents_frames:] = False

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(step_matrix)

        with self.progress_bar(total=len(step_matrix)) as progress_bar:
            for i, t in enumerate(step_matrix):
                if self.interrupt:
                    continue

                self._current_timestep = t
                valid_interval_start, valid_interval_end = valid_interval[i]
                latent_model_input = (
                    latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
                )
                timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()

                if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
                    noise_factor = 0.001 * addnoise_condition
                    # `latent_model_input` and `timestep` are already sliced to the current window, so the
                    # prefix frames live at window-relative indices [0, prefix - start). Using absolute
                    # indices here would hit the wrong frames whenever the window does not start at 0.
                    prefix_end = prefix_video_latents_frames - valid_interval_start
                    prefix_latents = latent_model_input[:, :, :prefix_end, :, :]
                    latent_model_input[:, :, :prefix_end, :, :] = (
                        prefix_latents * (1.0 - noise_factor) + torch.randn_like(prefix_latents) * noise_factor
                    )
                    timestep[:, :prefix_end] = addnoise_condition

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    enable_diffusion_forcing=True,
                    fps=fps_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        enable_diffusion_forcing=True,
                        fps=fps_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # Step only the frames flagged for update, each with its own scheduler and timestep.
                # `noise_pred` is window-relative while `latents`, `t` and the mask use absolute frame indices.
                update_mask_i = step_update_mask[i]
                for idx in range(valid_interval_start, valid_interval_end):
                    if update_mask_i[idx].item():
                        latents[:, :, idx, :, :] = sample_schedulers[idx].step(
                            noise_pred[:, :, idx - valid_interval_start, :, :],
                            t[idx],
                            latents[:, :, idx, :, :],
                            return_dict=False,
                        )[0]

                # call the callback, if provided
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(step_matrix) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # Handle latent accumulation for long videos or use the current latents for short videos
        if is_long_video:
            if accumulated_latents is None:
                accumulated_latents = latents
            else:
                # Keep overlap frames for conditioning but don't include them in final output
                accumulated_latents = torch.cat(
                    [accumulated_latents, latents[:, :, overlap_history_latent_frames:]],
                    dim=2,
                )

    if is_long_video:
        latents = accumulated_latents

    self._current_timestep = None

    # Final decoding step - convert latents to pixels
    if output_type != "latent":
        if last_image is not None:
            # Drop the appended end-frame latents before decoding
            latents = latents[:, :, :-prefix_video_latents_frames, :, :]
        # Always match the VAE dtype, not only when an end frame was supplied
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # NOTE: despite its name this holds the reciprocal 1 / std, hence the division below
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video = self.vae.decode(latents, return_dict=False)[0]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return SkyReelsV2PipelineOutput(frames=video)