diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.1-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
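
The most visible additions in 0.35.x are the new pipeline families (QwenImage, SkyReels-V2, Flux Kontext) and the modular pipeline framework. For orientation before the raw diff below, here is a minimal sketch of the new SkyReels-V2 text-to-video pipeline, condensed from the example docstring inside the newly added `pipeline_skyreels_v2.py`; the repository id, resolution, `flow_shift`, and guidance settings are taken from that docstring rather than independently verified:

```py
import torch
from diffusers import AutoencoderKLWan, SkyReelsV2Pipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video

# Checkpoint referenced in the pipeline's example docstring (540P variant also listed there).
model_id = "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers"

# The VAE is loaded in float32 and the rest of the pipeline in bfloat16, as in the docstring example.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = SkyReelsV2Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

# flow_shift=8.0 is the text-to-video setting suggested in the docstring (5.0 for image-to-video).
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.to("cuda")

frames = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    height=544,
    width=960,
    num_frames=97,
    guidance_scale=6.0,
    num_inference_steps=50,
).frames[0]
export_to_video(frames, "video.mp4", fps=24)
```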
diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py (new file)
@@ -0,0 +1,610 @@
+ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import html
+ from typing import Any, Callable, Dict, List, Optional, Union
+
+ import regex as re
+ import torch
+ from transformers import AutoTokenizer, UMT5EncoderModel
+
+ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+ from ...loaders import SkyReelsV2LoraLoaderMixin
+ from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+ from ...schedulers import UniPCMultistepScheduler
+ from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+ from ...utils.torch_utils import randn_tensor
+ from ...video_processor import VideoProcessor
+ from ..pipeline_utils import DiffusionPipeline
+ from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ if is_ftfy_available():
+     import ftfy
+
+
+ EXAMPLE_DOC_STRING = """\
+     Examples:
+         ```py
+         >>> import torch
+         >>> from diffusers import (
+         ...     SkyReelsV2Pipeline,
+         ...     UniPCMultistepScheduler,
+         ...     AutoencoderKLWan,
+         ... )
+         >>> from diffusers.utils import export_to_video
+
+         >>> # Load the pipeline
+         >>> # Available models:
+         >>> # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+         >>> # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+         >>> vae = AutoencoderKLWan.from_pretrained(
+         ...     "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+         ...     subfolder="vae",
+         ...     torch_dtype=torch.float32,
+         ... )
+         >>> pipe = SkyReelsV2Pipeline.from_pretrained(
+         ...     "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+         ...     vae=vae,
+         ...     torch_dtype=torch.bfloat16,
+         ... )
+         >>> flow_shift = 8.0  # 8.0 for T2V, 5.0 for I2V
+         >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+         >>> pipe = pipe.to("cuda")
+
+         >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+         >>> output = pipe(
+         ...     prompt=prompt,
+         ...     num_inference_steps=50,
+         ...     height=544,
+         ...     width=960,
+         ...     guidance_scale=6.0,  # 6.0 for T2V, 5.0 for I2V
+         ...     num_frames=97,
+         ... ).frames[0]
+         >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+         ```
+ """
+
+
+ def basic_clean(text):
+     text = ftfy.fix_text(text)
+     text = html.unescape(html.unescape(text))
+     return text.strip()
+
+
+ def whitespace_clean(text):
+     text = re.sub(r"\s+", " ", text)
+     text = text.strip()
+     return text
+
+
+ def prompt_clean(text):
+     text = whitespace_clean(basic_clean(text))
+     return text
+
+
+ class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+     r"""
+     Pipeline for Text-to-Video (t2v) generation using SkyReels-V2.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+     implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+     Args:
+         tokenizer ([`T5Tokenizer`]):
+             Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+             specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+         text_encoder ([`T5EncoderModel`]):
+             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+         transformer ([`SkyReelsV2Transformer3DModel`]):
+             Conditional Transformer to denoise the input latents.
+         scheduler ([`UniPCMultistepScheduler`]):
+             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+         vae ([`AutoencoderKLWan`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+     """
+
+     model_cpu_offload_seq = "text_encoder->transformer->vae"
+     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+     def __init__(
+         self,
+         tokenizer: AutoTokenizer,
+         text_encoder: UMT5EncoderModel,
+         transformer: SkyReelsV2Transformer3DModel,
+         vae: AutoencoderKLWan,
+         scheduler: UniPCMultistepScheduler,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             scheduler=scheduler,
+         )
+
+         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+     # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+     def _get_t5_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]] = None,
+         num_videos_per_prompt: int = 1,
+         max_sequence_length: int = 226,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         device = device or self._execution_device
+         dtype = dtype or self.text_encoder.dtype
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         prompt = [prompt_clean(u) for u in prompt]
+         batch_size = len(prompt)
+
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+         seq_lens = mask.gt(0).sum(dim=1).long()
+
+         prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+         prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+         prompt_embeds = torch.stack(
+             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+         )
+
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         _, seq_len, _ = prompt_embeds.shape
+         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+         return prompt_embeds
+
+     # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         do_classifier_free_guidance: bool = True,
+         num_videos_per_prompt: int = 1,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         max_sequence_length: int = 226,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         r"""
+         Encodes the prompt into text encoder hidden states.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 prompt to be encoded
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                 less than `1`).
+             do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+                 Whether to use classifier free guidance or not.
+             num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                 Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                 provided, text embeddings will be generated from `prompt` input argument.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                 argument.
+             device: (`torch.device`, *optional*):
+                 torch device
+             dtype: (`torch.dtype`, *optional*):
+                 torch dtype
+         """
+         device = device or self._execution_device
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=prompt,
+                 num_videos_per_prompt=num_videos_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+                 dtype=dtype,
+             )
+
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+
+             negative_prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=negative_prompt,
+                 num_videos_per_prompt=num_videos_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+                 dtype=dtype,
+             )
+
+         return prompt_embeds, negative_prompt_embeds
+
+     def check_inputs(
+         self,
+         prompt,
+         negative_prompt,
+         height,
+         width,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         callback_on_step_end_tensor_inputs=None,
+     ):
+         if height % 16 != 0 or width % 16 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+         if callback_on_step_end_tensor_inputs is not None and not all(
+             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+         ):
+             raise ValueError(
+                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+             )
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+         elif negative_prompt is not None and (
+             not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+         ):
+             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+     # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents
+     def prepare_latents(
+         self,
+         batch_size: int,
+         num_channels_latents: int = 16,
+         height: int = 480,
+         width: int = 832,
+         num_frames: int = 81,
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[torch.device] = None,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if latents is not None:
+             return latents.to(device=device, dtype=dtype)
+
+         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+         shape = (
+             batch_size,
+             num_channels_latents,
+             num_latent_frames,
+             int(height) // self.vae_scale_factor_spatial,
+             int(width) // self.vae_scale_factor_spatial,
+         )
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         return latents
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1.0
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def current_timestep(self):
+         return self._current_timestep
+
+     @property
+     def interrupt(self):
+         return self._interrupt
+
+     @property
+     def attention_kwargs(self):
+         return self._attention_kwargs
+
+     @torch.no_grad()
+     @replace_example_docstring(EXAMPLE_DOC_STRING)
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         negative_prompt: Union[str, List[str]] = None,
+         height: int = 544,
+         width: int = 960,
+         num_frames: int = 97,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 6.0,
+         num_videos_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         output_type: Optional[str] = "np",
+         return_dict: bool = True,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         max_sequence_length: int = 512,
+     ):
+         r"""
+         The call function to the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                 instead.
+             height (`int`, defaults to `544`):
+                 The height in pixels of the generated image.
+             width (`int`, defaults to `960`):
+                 The width in pixels of the generated image.
+             num_frames (`int`, defaults to `97`):
+                 The number of frames in the generated video.
+             num_inference_steps (`int`, defaults to `50`):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, defaults to `6.0`):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                 generation deterministic.
+             latents (`torch.Tensor`, *optional*):
+                 Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor is generated by sampling using the supplied random `generator`.
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                 provided, text embeddings are generated from the `prompt` input argument.
+             output_type (`str`, *optional*, defaults to `"np"`):
+                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+             attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                 A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                 each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+                 DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+                 list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+             callback_on_step_end_tensor_inputs (`List`, *optional*):
+                 The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                 `._callback_tensor_inputs` attribute of your pipeline class.
+             max_sequence_length (`int`, *optional*, defaults to `512`):
+                 The maximum sequence length for the text encoder.
+
+         Examples:
+
+         Returns:
+             [`~SkyReelsV2PipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+                 where the first element is a list with the generated images and the second element is a list of `bool`s
+                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+         """
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             negative_prompt,
+             height,
+             width,
+             prompt_embeds,
+             negative_prompt_embeds,
+             callback_on_step_end_tensor_inputs,
+         )
+
+         if num_frames % self.vae_scale_factor_temporal != 1:
+             logger.warning(
+                 f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+             )
+             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+         num_frames = max(num_frames, 1)
+
+         self._guidance_scale = guidance_scale
+         self._attention_kwargs = attention_kwargs
+         self._current_timestep = None
+         self._interrupt = False
+
+         device = self._execution_device
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # 3. Encode input prompt
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             num_videos_per_prompt=num_videos_per_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             max_sequence_length=max_sequence_length,
+             device=device,
+         )
+
+         transformer_dtype = self.transformer.dtype
+         prompt_embeds = prompt_embeds.to(transformer_dtype)
+         if negative_prompt_embeds is not None:
+             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_videos_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             num_frames,
+             torch.float32,
+             device,
+             generator,
+             latents,
+         )
+
+         # 6. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         self._num_timesteps = len(timesteps)
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 self._current_timestep = t
+                 latent_model_input = latents.to(transformer_dtype)
+                 timestep = t.expand(latents.shape[0])
+
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep,
+                     encoder_hidden_states=prompt_embeds,
+                     attention_kwargs=attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 if self.do_classifier_free_guidance:
+                     noise_uncond = self.transformer(
+                         hidden_states=latent_model_input,
+                         timestep=timestep,
+                         encoder_hidden_states=negative_prompt_embeds,
+                         attention_kwargs=attention_kwargs,
+                         return_dict=False,
+                     )[0]
+                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         self._current_timestep = None
+
+         if not output_type == "latent":
+             latents = latents.to(self.vae.dtype)
+             latents_mean = (
+                 torch.tensor(self.vae.config.latents_mean)
+                 .view(1, self.vae.config.z_dim, 1, 1, 1)
+                 .to(latents.device, latents.dtype)
+             )
+             latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                 latents.device, latents.dtype
+             )
+             latents = latents / latents_std + latents_mean
+             video = self.vae.decode(latents, return_dict=False)[0]
+             video = self.video_processor.postprocess_video(video, output_type=output_type)
+         else:
+             video = latents
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (video,)
+
+         return SkyReelsV2PipelineOutput(frames=video)