diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1063 @@
1
+ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ import inspect
17
+ import math
18
+ import re
19
+ from copy import deepcopy
20
+ from typing import Any, Callable, Dict, List, Optional, Union
21
+
22
+ import ftfy
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import AutoTokenizer, UMT5EncoderModel
26
+
27
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from ...loaders import SkyReelsV2LoraLoaderMixin
29
+ from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
30
+ from ...schedulers import UniPCMultistepScheduler
31
+ from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
32
+ from ...utils.torch_utils import randn_tensor
33
+ from ...video_processor import VideoProcessor
34
+ from ..pipeline_utils import DiffusionPipeline
35
+ from .pipeline_output import SkyReelsV2PipelineOutput
36
+
37
+
38
# torch_xla is optional: import it only when available and record the outcome in XLA_AVAILABLE.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger obtained from the project's logging helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# NOTE(review): `ftfy` is already imported unconditionally in the import block at the top of this
# file, so this guarded import has no effect. If ftfy is meant to be optional, the unconditional
# import is the one that should go - confirm intent.
if is_ftfy_available():
    import ftfy
49
+
50
+
51
+ EXAMPLE_DOC_STRING = """\
52
+ Examples:
53
+ ```py
54
+ >>> import torch
55
+ >>> from diffusers import (
56
+ ... SkyReelsV2DiffusionForcingVideoToVideoPipeline,
57
+ ... UniPCMultistepScheduler,
58
+ ... AutoencoderKLWan,
59
+ ... )
60
+ >>> from diffusers.utils import export_to_video
61
+
62
+ >>> # Load the pipeline
63
+ >>> # Available models:
64
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
65
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
66
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
67
+ >>> vae = AutoencoderKLWan.from_pretrained(
68
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
69
+ ... subfolder="vae",
70
+ ... torch_dtype=torch.float32,
71
+ ... )
72
+ >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
73
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
74
+ ... vae=vae,
75
+ ... torch_dtype=torch.bfloat16,
76
+ ... )
77
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
78
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
79
+ >>> pipe = pipe.to("cuda")
80
+
81
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
82
+
83
+ >>> output = pipe(
84
+ ... prompt=prompt,
85
+ ... num_inference_steps=50,
86
+ ... height=544,
87
+ ... width=960,
88
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
89
+ ... num_frames=97,
90
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
91
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
92
+ ... addnoise_condition=20, # Improves consistency in long video generation
93
+ ... ).frames[0]
94
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
95
+ ```
96
+ """
97
+
98
+
99
def basic_clean(text):
    """Repair and decode raw prompt text.

    The text is passed through ``ftfy.fix_text``, HTML-unescaped twice and then stripped of
    leading/trailing whitespace.
    """
    repaired = ftfy.fix_text(text)
    # Two unescape passes so that doubly-escaped entities (e.g. "&amp;lt;") are fully decoded.
    decoded = html.unescape(html.unescape(repaired))
    return decoded.strip()
103
+
104
+
105
def whitespace_clean(text):
    """Collapse every run of whitespace into a single space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()
109
+
110
+
111
def prompt_clean(text):
    """Run ``basic_clean`` on ``text`` and then normalise whitespace with ``whitespace_clean``."""
    repaired = basic_clean(text)
    return whitespace_clean(repaired)
114
+
115
+
116
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler via `set_timesteps` and return the resulting timestep schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler whose schedule is configured and read back.
        num_inference_steps (`int`, *optional*):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        A `(timesteps, num_inference_steps)` tuple: `scheduler.timesteps` after configuration and the number of
        inference steps (the schedule length when a custom schedule was supplied).

    Raises:
        ValueError: If both `timesteps` and `sigmas` are passed, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: let the scheduler derive its own spacing from the step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # A custom schedule was requested: make sure `set_timesteps` exposes the matching parameter.
    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
174
+
175
+
176
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder output object.

    If the output has a `latent_dist` attribute, it is sampled (`sample_mode="sample"`, using `generator`) or its
    mode is taken (`sample_mode="argmax"`). Otherwise the `latents` attribute is returned if present.

    Raises:
        AttributeError: If none of the above applies.
    """
    has_distribution = hasattr(encoder_output, "latent_dist")

    if has_distribution and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_distribution and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
188
+
189
+
190
+ class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
191
+ """
192
+ Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing.
193
+
194
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
195
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
196
+
197
+ Args:
198
+ tokenizer ([`AutoTokenizer`]):
199
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
200
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
201
+ text_encoder ([`UMT5EncoderModel`]):
202
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
203
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
204
+ transformer ([`SkyReelsV2Transformer3DModel`]):
205
+ Conditional Transformer to denoise the encoded image latents.
206
+ scheduler ([`UniPCMultistepScheduler`]):
207
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
208
+ vae ([`AutoencoderKLWan`]):
209
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
210
+ """
211
+
212
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
213
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
214
+
215
def __init__(
    self,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    transformer: SkyReelsV2Transformer3DModel,
    vae: AutoencoderKLWan,
    scheduler: UniPCMultistepScheduler,
):
    """Register the pipeline components and derive the VAE scale factors and video processor."""
    super().__init__()

    components = {
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "transformer": transformer,
        "scheduler": scheduler,
    }
    self.register_modules(**components)

    # Scale factors are derived from the VAE's `temperal_downsample` config (attribute name as spelled
    # by the VAE); fall back to 4 (temporal) and 8 (spatial) when no VAE is registered.
    if getattr(self, "vae", None):
        downsample_flags = self.vae.temperal_downsample
        temporal_factor = 2 ** sum(downsample_flags)
        spatial_factor = 2 ** len(downsample_flags)
    else:
        temporal_factor = 4
        spatial_factor = 8

    self.vae_scale_factor_temporal = temporal_factor
    self.vae_scale_factor_spatial = spatial_factor
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
236
+
237
# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize and encode `prompt` with the text encoder.

    Each prompt is cleaned with `prompt_clean`, tokenized to `max_sequence_length`, encoded, and every position
    beyond the prompt's attention-mask length is replaced with zeros. The result is then repeated
    `num_videos_per_prompt` times per prompt.

    `device` defaults to `self._execution_device` and `dtype` to `self.text_encoder.dtype`.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    if isinstance(prompt, str):
        prompt = [prompt]
    cleaned_prompts = [prompt_clean(entry) for entry in prompt]
    batch_size = len(cleaned_prompts)

    tokenized = self.tokenizer(
        cleaned_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids
    attention_mask = tokenized.attention_mask
    # Per-prompt count of mask entries > 0 (presumably the non-padding tokens).
    valid_lengths = attention_mask.gt(0).sum(dim=1).long()

    hidden_states = self.text_encoder(input_ids.to(device), attention_mask.to(device)).last_hidden_state
    hidden_states = hidden_states.to(dtype=dtype, device=device)

    # Keep only the first `valid_length` rows of each embedding and pad back to
    # `max_sequence_length` with zeros.
    padded_embeddings = []
    for embedding, valid_length in zip(hidden_states, valid_lengths):
        kept = embedding[:valid_length]
        zero_pad = kept.new_zeros(max_sequence_length - kept.size(0), kept.size(1))
        padded_embeddings.append(torch.cat([kept, zero_pad]))
    prompt_embeds = torch.stack(padded_embeddings, dim=0)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    seq_len = prompt_embeds.shape[1]
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
278
+
279
# Adapted from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Produce the positive and (optionally) negative text embeddings for a generation call.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. Ignored for encoding when `prompt_embeds` is supplied.
        negative_prompt (`str` or `List[str]`, *optional*):
            Negative prompt(s). Defaults to the empty string when guidance is enabled and no
            `negative_prompt_embeds` are supplied.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether negative embeddings are needed.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos generated per prompt; forwarded to `_get_t5_prompt_embeds`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed prompt embeddings; returned unchanged when given.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-computed negative embeddings; returned unchanged when given.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Forwarded to `_get_t5_prompt_embeds`.
        device (`torch.device`, *optional*):
            Target device; defaults to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            Target dtype, forwarded to `_get_t5_prompt_embeds`.

    Returns:
        A `(prompt_embeds, negative_prompt_embeds)` tuple. The second element stays `None` when guidance is
        disabled and no negative embeddings were passed in.

    Raises:
        TypeError: If `prompt` and `negative_prompt` end up with different types.
        ValueError: If the negative prompt batch size differs from the prompt batch size.
    """
    device = device or self._execution_device

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    # Nothing more to do unless guidance needs negative embeddings that were not supplied.
    if not (do_classifier_free_guidance and negative_prompt_embeds is None):
        return prompt_embeds, negative_prompt_embeds

    negative_prompt = negative_prompt or ""
    if isinstance(negative_prompt, str):
        # Broadcast a single negative prompt across the whole batch.
        negative_prompt = batch_size * [negative_prompt]

    if prompt is not None and type(prompt) is not type(negative_prompt):
        raise TypeError(
            f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
            f" {type(prompt)}."
        )
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    negative_prompt_embeds = self._get_t5_prompt_embeds(
        prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        device=device,
        dtype=dtype,
    )

    return prompt_embeds, negative_prompt_embeds
360
+
361
def check_inputs(
    self,
    prompt,
    negative_prompt,
    height,
    width,
    video=None,
    latents=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    overlap_history=None,
    num_frames=None,
    base_num_frames=None,
):
    """Validate the user-facing arguments of a pipeline call.

    Checks, in order: spatial size divisibility by 16, allowed callback tensor names, mutual exclusivity and
    types of the prompt / embedding arguments, mutual exclusivity of `video` and `latents`, and that
    `overlap_history` is given whenever `num_frames` exceeds `base_num_frames`.

    Returns:
        `None` when every check passes.

    Raises:
        ValueError: On the first check that fails.
    """
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    if video is not None and latents is not None:
        raise ValueError("Only one of `video` or `latents` should be provided")

    # `num_frames` and `base_num_frames` both default to None, and ordering None against an int (or
    # None) raises TypeError, so only run the long-video check when both were actually supplied.
    frame_counts_known = num_frames is not None and base_num_frames is not None
    if frame_counts_known and num_frames > base_num_frames and overlap_history is None:
        raise ValueError(
            "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
            "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
        )
415
+
416
def prepare_latents(
    self,
    video: torch.Tensor,
    batch_size: int = 1,
    num_channels_latents: int = 16,
    height: int = 480,
    width: int = 832,
    num_frames: int = 97,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    video_latents: Optional[torch.Tensor] = None,
    base_latent_num_frames: Optional[int] = None,
    overlap_history: Optional[int] = None,
    causal_block_size: Optional[int] = None,
    overlap_history_latent_frames: Optional[int] = None,
    long_video_iter: Optional[int] = None,
) -> Union[torch.Tensor, tuple]:
    """Build the noise latents and the prefix (conditioning) latents for one generation window.

    When `latents` is supplied it is only moved to `device`/`dtype` and returned as a bare tensor.
    NOTE(review): that early return has a different shape from the main path's 4-tuple; callers that
    unpack four values will fail when passing `latents` - confirm against the call site.

    Otherwise:
      * on the first window (`long_video_iter == 0`) the last `overlap_history` frames of each input video
        are encoded with the VAE (mode of the latent distribution) and normalised with the VAE config's
        `latents_mean` / `latents_std`;
      * on later windows the prefix is the last `overlap_history_latent_frames` frames of `video_latents`;
      * the prefix is truncated so its frame count is a multiple of `causal_block_size`;
      * fresh noise of shape `(batch_size, num_channels_latents, frames, height // spatial, width // spatial)`
        is drawn, where `frames` is capped at `base_latent_num_frames`.

    Returns:
        Either the converted `latents` tensor (early return), or the tuple
        `(latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames)`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    if latents is not None:
        return latents.to(device=device, dtype=dtype)

    # `latents` is None from here on, so the frame count always comes from `num_frames`.
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial

    if long_video_iter == 0:
        # First window: derive the prefix from the tail of the input video(s).
        prefix_video_latents = [
            retrieve_latents(
                self.vae.encode(
                    # 4-D entries get a leading batch axis added before slicing dim 2.
                    vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:]
                ),
                sample_mode="argmax",
            )
            for vid in video
        ]
        prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype)

        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(device, self.vae.dtype)
        )
        # Stored as the reciprocal so normalisation below is a multiplication.
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            device, self.vae.dtype
        )
        prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std
    else:
        # Later windows: reuse the tail of the latents generated so far.
        prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]

    if prefix_video_latents.shape[2] % causal_block_size != 0:
        truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
        logger.warning(
            f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
            f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
            f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
        )
        prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
    prefix_video_latents_frames = prefix_video_latents.shape[2]

    # Latent frames already covered by previous windows, counting the overlap only once.
    finished_frame_num = (
        long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames
    )
    left_frame_num = num_latent_frames - finished_frame_num
    num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)

    shape = (
        batch_size,
        num_channels_latents,
        num_latent_frames,
        latent_height,
        latent_width,
    )
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

    return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
500
+
501
# Adapted from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
# (the `shrink_interval_with_mask` branch no longer overwrites the list of per-iteration update masks)
def generate_timestep_matrix(
    self,
    num_latent_frames: int,
    step_template: torch.Tensor,
    base_num_latent_frames: int,
    ar_step: int = 5,
    num_pre_ready: int = 0,
    causal_block_size: int = 1,
    shrink_interval_with_mask: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
    """
    This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
    across temporal frames. It supports both synchronous and asynchronous generation modes:

    **Synchronous Mode** (ar_step=0, causal_block_size=1):
    - All frames are denoised simultaneously at each timestep
    - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
    - Simpler but may have less temporal consistency for long videos

    **Asynchronous Mode** (ar_step>0, causal_block_size>1):
    - Frames are grouped into causal blocks and processed block/chunk-wise
    - Each block is denoised in a staggered pattern creating a "denoising wave"
    - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
    - Creates stronger temporal dependencies and better consistency

    Args:
        num_latent_frames (int): Total number of latent frames to generate
        step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
        base_num_latent_frames (int): Maximum frames the model can process in one forward pass
        ar_step (int, optional): Autoregressive step size for temporal lag.
            0 = synchronous, >0 = asynchronous. Defaults to 5.
        num_pre_ready (int, optional):
            Number of frames already denoised (e.g., from prefix in a video2video task).
            Defaults to 0.
        causal_block_size (int, optional): Number of frames processed as a causal block.
            Defaults to 1.
        shrink_interval_with_mask (bool, optional):
            When True, the processing interval starts just past the last block updated in the first iteration
            (instead of at the full model capacity) and grows as later blocks become active.
            Defaults to False.

    Returns:
        tuple containing:
            - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
              [num_iterations, num_latent_frames]
            - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
              num_latent_frames]
            - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
              [num_iterations, num_latent_frames]
            - valid_interval (list[tuple]): List of (start, end) intervals for each iteration

    Raises:
        ValueError: If ar_step is too small for the given configuration
    """
    # Per-iteration rows of the schedule; stacked into tensors at the end.
    step_matrix, step_index = [], []
    update_mask, valid_interval = [], []

    # One more than the number of template steps; this value marks a block as fully denoised.
    num_iterations = len(step_template) + 1

    # The schedule is built per causal block and expanded back to frames at the end.
    num_blocks = num_latent_frames // causal_block_size
    base_num_blocks = base_num_latent_frames // causal_block_size

    # When the sequence exceeds the model capacity, ar_step must be large enough for the staggered pattern.
    if base_num_blocks < num_blocks:
        min_ar_step = len(step_template) / base_num_blocks
        if ar_step < min_ar_step:
            raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")

    # Pad the template so that counters can index it directly:
    #   index 0 -> 999 (dummy value, counters start from 1), last index -> 0 (completely denoised).
    step_template = torch.cat(
        [
            torch.tensor([999], dtype=torch.int64, device=step_template.device),
            step_template.long(),
            torch.tensor([0], dtype=torch.int64, device=step_template.device),
        ]
    )

    # Denoising progress per block: 0 = not started, num_iterations = fully denoised.
    pre_row = torch.zeros(num_blocks, dtype=torch.long)

    # Blocks covered by already-clean frames (e.g. a video prefix) start in the final state.
    if num_pre_ready > 0:
        pre_row[: num_pre_ready // causal_block_size] = num_iterations

    # Emit one schedule row per iteration until every block has reached the last template step.
    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_blocks, dtype=torch.long)

        for i in range(num_blocks):
            if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                # First block, or the preceding block has finished: advance by one step.
                new_row[i] = pre_row[i] + 1
            else:
                # Otherwise lag behind the preceding block by ar_step steps.
                new_row[i] = new_row[i - 1] - ar_step

        new_row = new_row.clamp(0, num_iterations)

        # A block is updated when it progressed in this iteration and is not in the finished state.
        update_mask.append((new_row != pre_row) & (new_row != num_iterations))

        step_index.append(new_row)  # Index into step_template
        step_matrix.append(step_template[new_row])  # Actual timestep values
        pre_row = new_row

    # End (exclusive) of the block interval handed to the model; defaults to the model capacity.
    terminal_flag = base_num_blocks

    if shrink_interval_with_mask:
        # Start just past the last block updated in the first iteration. The first mask is kept in its own
        # variable so that `update_mask` remains the list of all per-iteration masks used below.
        first_mask = update_mask[0]
        idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
        last_update_idx = idx_sequence[first_mask][-1].item()
        terminal_flag = last_update_idx + 1

    for curr_mask in update_mask:
        # Grow the interval by one block when the block right after it needs an update.
        if terminal_flag < num_blocks and curr_mask[terminal_flag]:
            terminal_flag += 1
        # [start, end) with at most base_num_blocks blocks.
        valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))

    step_update_mask = torch.stack(update_mask, dim=0)
    step_index = torch.stack(step_index, dim=0)
    step_matrix = torch.stack(step_matrix, dim=0)

    # Replicate every block's schedule to all frames of that block and rescale the intervals to frames.
    if causal_block_size > 1:
        step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
        valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]

    return step_matrix, step_index, step_update_mask, valid_interval
657
+
658
@property
def guidance_scale(self):
    """Guidance scale stored in `self._guidance_scale` by `__call__`."""
    return self._guidance_scale
661
+
662
@property
def do_classifier_free_guidance(self):
    """Whether classifier-free guidance is active, i.e. the stored guidance scale is greater than 1.0."""
    scale = self._guidance_scale
    return scale > 1.0
665
+
666
@property
def num_timesteps(self):
    """Number of denoising iterations stored in `self._num_timesteps` (set by `__call__` to `len(step_matrix)`)."""
    return self._num_timesteps
669
+
670
@property
def current_timestep(self):
    """Timestep row currently being processed by `__call__`; `None` before the loop starts and after it ends."""
    return self._current_timestep
673
+
674
@property
def interrupt(self):
    """Interrupt flag stored in `self._interrupt`; while it is truthy `__call__` skips its denoising iterations."""
    return self._interrupt
677
+
678
@property
def attention_kwargs(self):
    """Attention kwargs stored in `self._attention_kwargs` by `__call__`."""
    return self._attention_kwargs
681
+
682
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    video: List[Image.Image],
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 544,
    width: int = 960,
    num_frames: int = 120,
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[
        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
    ] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
    overlap_history: Optional[int] = None,
    addnoise_condition: float = 0,
    base_num_frames: int = 97,
    ar_step: int = 0,
    causal_block_size: Optional[int] = None,
    fps: int = 24,
):
    r"""
    The call function to the pipeline for generation.

    Args:
        video (`List[Image.Image]`):
            The video to guide the video generation. Its last `overlap_history` frames are encoded as the clean
            prefix of the first window, and the preprocessed input is prepended to the decoded result.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
            instead.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        height (`int`, defaults to `544`):
            The height of the generated video.
        width (`int`, defaults to `960`):
            The width of the generated video.
        num_frames (`int`, defaults to `120`):
            The number of frames in the generated video. If `num_frames - 1` is not divisible by the temporal VAE
            scale factor, the value is adjusted and a warning is logged.
        num_inference_steps (`int`, defaults to `50`):
            The number of denoising steps. More denoising steps usually lead to a higher quality video at the
            expense of slower inference.
        guidance_scale (`float`, defaults to `6.0`):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            The number of videos to generate per prompt. The value passed in is currently overridden with `1`.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
            generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor is generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
            provided, text embeddings are generated from the `negative_prompt` input argument.
        output_type (`str`, *optional*, defaults to `"np"`):
            The output format of the generated video. Choose between `PIL.Image` or `np.array`. With `"latent"`
            the accumulated latents are returned without decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
            A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
            each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
            DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
            list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        max_sequence_length (`int`, *optional*, defaults to `512`):
            The maximum sequence length of the prompt.
        overlap_history (`int`, *optional*, defaults to `None`):
            Number of frames to overlap for smooth transitions in long videos. 17 and 37 are recommended values.
            Note that this method uses the value in arithmetic unconditionally, so it has to be an integer here.
        addnoise_condition (`float`, *optional*, defaults to `0`):
            This is used to help smooth the long video generation by adding some noise to the clean condition. Too
            large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
            ones, but it is recommended to not exceed 50. A warning is logged for values above 60.
        base_num_frames (`int`, *optional*, defaults to `97`):
            97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
        ar_step (`int`, *optional*, defaults to `0`):
            Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
            inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
            to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
            sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
            inference may improve the instruction following and visual consistent performance.
        causal_block_size (`int`, *optional*, defaults to `None`):
            The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
            0). When `None`, `transformer.config.num_frame_per_block` is used.
        fps (`int`, *optional*, defaults to `24`):
            Frame rate of the generated video

    Examples:

    Returns:
        [`~SkyReelsV2PipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a one-element `tuple`
            holding the generated video (or the latents when `output_type="latent"`) is returned.
    """

    # Callback objects declare which tensors they want to receive.
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
    width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
    # The caller-supplied value is discarded: exactly one video is generated per prompt.
    num_videos_per_prompt = 1

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        negative_prompt,
        height,
        width,
        video,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
        callback_on_step_end_tensor_inputs,
        overlap_history,
        num_frames,
        base_num_frames,
    )

    if addnoise_condition > 60:
        logger.warning(
            f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
        )

    # The frame count must be of the form k * vae_scale_factor_temporal + 1.
    if num_frames % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
    num_frames = max(num_frames, 1)

    # State exposed through the read-only properties of the pipeline.
    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # NOTE(review): `video_original` is only bound when `latents` is None, yet it is passed to
    # `prepare_latents` and concatenated with the decoded video below. Calling with `latents` would therefore
    # hit an unbound local unless `check_inputs` rules that combination out - confirm.
    if latents is None:
        video_original = self.video_processor.preprocess_video(video, height=height, width=width).to(
            device, dtype=torch.float32
        )

    if causal_block_size is None:
        causal_block_size = self.transformer.config.num_frame_per_block
    else:
        self.transformer._set_ar_attention(causal_block_size)

    # One fps id per prompt-embedding row: 0 when fps == 16, otherwise 1.
    fps_embeds = [fps] * prompt_embeds.shape[0]
    fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]

    # Long video generation
    accumulated_latents = None
    # NOTE(review): this raises a TypeError when `overlap_history` is None, unless `check_inputs` has already
    # rejected that case - confirm.
    overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    base_latent_num_frames = (
        (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
        if base_num_frames is not None
        else num_latent_frames
    )
    # Number of windows. With stride = base_latent_num_frames - overlap_history_latent_frames (> 0) this equals
    # 1 + ceil((num_latent_frames - base_latent_num_frames) / stride).
    # NOTE(review): stride == 0 divides by zero, and num_latent_frames <= overlap_history_latent_frames yields
    # n_iter <= 0, which leaves `accumulated_latents` as None - presumably excluded by input validation; verify.
    n_iter = (
        1
        + (num_latent_frames - base_latent_num_frames - 1)
        // (base_latent_num_frames - overlap_history_latent_frames)
        + 1
    )
    for long_video_iter in range(n_iter):
        logger.debug(f"Processing iteration {long_video_iter + 1}/{n_iter} for long video generation...")

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
            self.prepare_latents(
                video_original,
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                height,
                width,
                num_frames,
                torch.float32,
                device,
                generator,
                latents if long_video_iter == 0 else None,
                video_latents=accumulated_latents,  # Pass latents directly instead of decoded video
                overlap_history=overlap_history,
                base_latent_num_frames=base_latent_num_frames,
                causal_block_size=causal_block_size,
                overlap_history_latent_frames=overlap_history_latent_frames,
                long_video_iter=long_video_iter,
            )
        )

        # The leading frames of the window are the clean prefix rather than noise.
        if prefix_video_latents_frames > 0:
            latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)

        # 6. Prepare sample schedulers and timestep matrix
        # Every latent frame gets its own scheduler copy; frames are stepped individually further below.
        sample_schedulers = []
        for _ in range(current_num_latent_frames):
            sample_scheduler = deepcopy(self.scheduler)
            sample_scheduler.set_timesteps(num_inference_steps, device=device)
            sample_schedulers.append(sample_scheduler)
        # The window size is also passed as the capacity argument; the prefix frames count as pre-ready.
        step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
            current_num_latent_frames,
            timesteps,
            current_num_latent_frames,
            ar_step,
            prefix_video_latents_frames,
            causal_block_size,
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(step_matrix)

        with self.progress_bar(total=len(step_matrix)) as progress_bar:
            for i, t in enumerate(step_matrix):
                # Once interrupted, the remaining iterations are skipped without any denoising.
                if self.interrupt:
                    continue

                self._current_timestep = t
                # Only the frames inside the valid interval are fed to the transformer in this iteration.
                valid_interval_start, valid_interval_end = valid_interval[i]
                latent_model_input = (
                    latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
                )
                timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()

                # Blend a little Gaussian noise into the prefix frames of the model input (the stored latents
                # are untouched because the input is a clone) and report `addnoise_condition` as their timestep.
                if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
                    noise_factor = 0.001 * addnoise_condition
                    latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
                        latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
                        * (1.0 - noise_factor)
                        + torch.randn_like(
                            latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
                        )
                        * noise_factor
                    )
                    timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    enable_diffusion_forcing=True,
                    fps=fps_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                )[0]
                # Classifier-free guidance: second pass with the negative prompt, then extrapolate.
                if self.do_classifier_free_guidance:
                    noise_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        enable_diffusion_forcing=True,
                        fps=fps_embeds,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                # Step only the frames flagged for update, each with its own scheduler and timestep.
                # `noise_pred` is indexed relative to the start of the valid interval.
                update_mask_i = step_update_mask[i]
                for idx in range(valid_interval_start, valid_interval_end):
                    if update_mask_i[idx].item():
                        latents[:, :, idx, :, :] = sample_schedulers[idx].step(
                            noise_pred[:, :, idx - valid_interval_start, :, :],
                            t[idx],
                            latents[:, :, idx, :, :],
                            return_dict=False,
                        )[0]

                # Call the callback, if provided; it may replace the latents and the prompt embeddings.
                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # Advance the progress bar on the last iteration, or after the warmup steps on every multiple
                # of the scheduler order.
                if i == len(step_matrix) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if accumulated_latents is None:
            accumulated_latents = latents
        else:
            # Keep overlap frames for conditioning but don't include them in final output
            accumulated_latents = torch.cat(
                [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
            )

    latents = accumulated_latents

    self._current_timestep = None

    # Final decoding step - convert latents to pixels
    if not output_type == "latent":
        latents = latents.to(self.vae.dtype)
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # `latents_std` holds the reciprocal of the configured std, so dividing by it rescales by the std.
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            latents.device, latents.dtype
        )
        latents = latents / latents_std + latents_mean
        video_generated = self.vae.decode(latents, return_dict=False)[0]
        # The preprocessed input video is prepended to the decoded frames along dim 2.
        video = torch.cat([video_original, video_generated], dim=2)
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (video,)

    return SkyReelsV2PipelineOutput(frames=video)