diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
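
Only two of the 191 changed files are shown in full below: the new Wan modular-pipeline blocks `before_denoise.py` and `decoders.py`, part of the `diffusers.modular_pipelines` package introduced in this release. Each block subclasses `ModularPipelineBlocks`, declares its components and input/output parameters, and reads and writes a shared `PipelineState`. For rough orientation only, a hedged sketch of how such blocks might be composed follows; `SequentialPipelineBlocks.from_blocks_dict` and `init_pipeline` are assumptions about the modular-pipelines API surface and do not appear in the hunks below, and the repo id is illustrative:

    # A minimal composition sketch (assumed API; not part of this diff).
    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.wan.before_denoise import (
        WanInputStep,
        WanPrepareLatentsStep,
        WanSetTimestepsStep,
    )
    from diffusers.modular_pipelines.wan.decoders import WanDecodeStep

    blocks = SequentialPipelineBlocks.from_blocks_dict(
        {
            # Blocks defined in the hunks below; a full pipeline would also
            # include the encode/denoise blocks from wan/encoders.py and
            # wan/denoise.py listed above.
            "input": WanInputStep(),
            "set_timesteps": WanSetTimestepsStep(),
            "prepare_latents": WanPrepareLatentsStep(),
            "decode": WanDecodeStep(),
        }
    )
    # Bind the blocks to concrete components from a model repo; components
    # would still need to be loaded before calling the pipeline.
    pipeline = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")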
diffusers/modular_pipelines/wan/before_denoise.py (new file)
@@ -0,0 +1,365 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import List, Optional, Union
+
+ import torch
+
+ from ...schedulers import UniPCMultistepScheduler
+ from ...utils import logging
+ from ...utils.torch_utils import randn_tensor
+ from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+ from .modular_pipeline import WanModularPipeline
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # TODO(yiyi, aryan): We need another step before the text encoder to set the `num_inference_steps` attribute for the
+ # guider, so that things like when to do guidance and how many conditions to prepare can be determined. Currently,
+ # the Guiders always assume guidance is wanted, so negative embeddings are prepared regardless of the guider's
+ # configuration.
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+ def retrieve_timesteps(
+     scheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     **kwargs,
+ ):
+     r"""
+     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+     Args:
+         scheduler (`SchedulerMixin`):
+             The scheduler to get timesteps from.
+         num_inference_steps (`int`):
+             The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+             must be `None`.
+         device (`str` or `torch.device`, *optional*):
+             The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+         timesteps (`List[int]`, *optional*):
+             Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is
+             passed, `num_inference_steps` and `sigmas` must be `None`.
+         sigmas (`List[float]`, *optional*):
+             Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+             `num_inference_steps` and `timesteps` must be `None`.
+
+     Returns:
+         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
+         the second element is the number of inference steps.
+     """
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+     if timesteps is not None:
+         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accepts_timesteps:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" timestep schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     elif sigmas is not None:
+         accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+         if not accept_sigmas:
+             raise ValueError(
+                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                 f" sigma schedules. Please check whether you are using the correct scheduler."
+             )
+         scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+
+ class WanInputStep(ModularPipelineBlocks):
+     model_name = "wan"
+
+     @property
+     def description(self) -> str:
+         return (
+             "Input processing step that:\n"
+             "  1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+             "  2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
+             "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+             "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+             "have a final batch_size of batch_size * num_videos_per_prompt."
+         )
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("num_videos_per_prompt", default=1),
+         ]
+
+     @property
+     def intermediate_inputs(self) -> List[InputParam]:
+         return [
+             InputParam(
+                 "prompt_embeds",
+                 required=True,
+                 type_hint=torch.Tensor,
+                 description="Pre-generated text embeddings. Can be generated from the text_encoder step.",
+             ),
+             InputParam(
+                 "negative_prompt_embeds",
+                 type_hint=torch.Tensor,
+                 description="Pre-generated negative text embeddings. Can be generated from the text_encoder step.",
+             ),
+         ]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "batch_size",
+                 type_hint=int,
+                 description="Number of prompts; the final batch size of model inputs should be batch_size * num_videos_per_prompt",
+             ),
+             OutputParam(
+                 "dtype",
+                 type_hint=torch.dtype,
+                 description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+             ),
+             OutputParam(
+                 "prompt_embeds",
+                 type_hint=torch.Tensor,
+                 kwargs_type="guider_input_fields",  # already in the intermediates state; declared again for guider_input_fields
+                 description="text embeddings used to guide the image generation",
+             ),
+             OutputParam(
+                 "negative_prompt_embeds",
+                 type_hint=torch.Tensor,
+                 kwargs_type="guider_input_fields",  # already in the intermediates state; declared again for guider_input_fields
+                 description="negative text embeddings used to guide the image generation",
+             ),
+         ]
+
+     def check_inputs(self, components, block_state):
+         if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+             if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+                 raise ValueError(
+                     "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                     f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+                     f" {block_state.negative_prompt_embeds.shape}."
+                 )
+
+     @torch.no_grad()
+     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+         self.check_inputs(components, block_state)
+
+         block_state.batch_size = block_state.prompt_embeds.shape[0]
+         block_state.dtype = block_state.prompt_embeds.dtype
+
+         _, seq_len, _ = block_state.prompt_embeds.shape
+         block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
+         block_state.prompt_embeds = block_state.prompt_embeds.view(
+             block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+         )
+
+         if block_state.negative_prompt_embeds is not None:
+             _, seq_len, _ = block_state.negative_prompt_embeds.shape
+             block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+                 1, block_state.num_videos_per_prompt, 1
+             )
+             block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+                 block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+             )
+
+         self.set_block_state(state, block_state)
+
+         return components, state
+
+
+ class WanSetTimestepsStep(ModularPipelineBlocks):
+     model_name = "wan"
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return [
+             ComponentSpec("scheduler", UniPCMultistepScheduler),
+         ]
+
+     @property
+     def description(self) -> str:
+         return "Step that sets the scheduler's timesteps for inference"
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("num_inference_steps", default=50),
+             InputParam("timesteps"),
+             InputParam("sigmas"),
+         ]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+             OutputParam(
+                 "num_inference_steps",
+                 type_hint=int,
+                 description="The number of denoising steps to perform at inference time",
+             ),
+         ]
+
+     @torch.no_grad()
+     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+         block_state.device = components._execution_device
+
+         block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+             components.scheduler,
+             block_state.num_inference_steps,
+             block_state.device,
+             block_state.timesteps,
+             block_state.sigmas,
+         )
+
+         self.set_block_state(state, block_state)
+         return components, state
+
+
+ class WanPrepareLatentsStep(ModularPipelineBlocks):
+     model_name = "wan"
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return []
+
+     @property
+     def description(self) -> str:
+         return "Prepare latents step that prepares the latents for the text-to-video generation process"
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("height", type_hint=int),
+             InputParam("width", type_hint=int),
+             InputParam("num_frames", type_hint=int),
+             InputParam("latents", type_hint=Optional[torch.Tensor]),
+             InputParam("num_videos_per_prompt", type_hint=int, default=1),
+         ]
+
+     @property
+     def intermediate_inputs(self) -> List[InputParam]:
+         return [
+             InputParam("generator"),
+             InputParam(
+                 "batch_size",
+                 required=True,
+                 type_hint=int,
+                 description="Number of prompts; the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in the input step.",
+             ),
+             InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+         ]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+             )
+         ]
+
+     @staticmethod
+     def check_inputs(components, block_state):
+         if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+             block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+         ):
+             raise ValueError(
+                 f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+             )
+         if block_state.num_frames is not None and (
+             block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
+         ):
+             raise ValueError(
+                 f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
+             )
+
+     @staticmethod
+     # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
+     def prepare_latents(
+         comp,
+         batch_size: int,
+         num_channels_latents: int = 16,
+         height: int = 480,
+         width: int = 832,
+         num_frames: int = 81,
+         dtype: Optional[torch.dtype] = None,
+         device: Optional[torch.device] = None,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if latents is not None:
+             return latents.to(device=device, dtype=dtype)
+
+         num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
+         shape = (
+             batch_size,
+             num_channels_latents,
+             num_latent_frames,
+             int(height) // comp.vae_scale_factor_spatial,
+             int(width) // comp.vae_scale_factor_spatial,
+         )
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         return latents
+
+     @torch.no_grad()
+     def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+
+         block_state.height = block_state.height or components.default_height
+         block_state.width = block_state.width or components.default_width
+         block_state.num_frames = block_state.num_frames or components.default_num_frames
+         block_state.device = components._execution_device
+         block_state.dtype = torch.float32  # Wan latents should be torch.float32 for best quality
+         block_state.num_channels_latents = components.num_channels_latents
+
+         self.check_inputs(components, block_state)
+
+         block_state.latents = self.prepare_latents(
+             components,
+             block_state.batch_size * block_state.num_videos_per_prompt,
+             block_state.num_channels_latents,
+             block_state.height,
+             block_state.width,
+             block_state.num_frames,
+             block_state.dtype,
+             block_state.device,
+             block_state.generator,
+             block_state.latents,
+         )
+
+         self.set_block_state(state, block_state)
+
+         return components, state
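
Before moving to the decoder file, a minimal usage sketch of the `retrieve_timesteps` helper defined above, with the `UniPCMultistepScheduler` that `WanSetTimestepsStep` declares as its scheduler component. Whether a given scheduler accepts custom `timesteps` or `sigmas` is exactly what the helper's `inspect.signature` checks guard; the calls below describe the contract shown in the hunk, not new behavior:

    from diffusers import UniPCMultistepScheduler

    # retrieve_timesteps as defined in before_denoise.py above.
    scheduler = UniPCMultistepScheduler()

    # Default path: the scheduler builds its own 30-step schedule.
    timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
    assert num_steps == len(timesteps)

    # `timesteps` and `sigmas` are mutually exclusive overrides.
    try:
        retrieve_timesteps(scheduler, timesteps=[999, 500, 1], sigmas=[1.0, 0.5, 0.0])
    except ValueError as err:
        print(err)  # "Only one of `timesteps` or `sigmas` can be passed. ..."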
diffusers/modular_pipelines/wan/decoders.py (new file)
@@ -0,0 +1,105 @@
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Union
+
+ import numpy as np
+ import PIL
+ import torch
+
+ from ...configuration_utils import FrozenDict
+ from ...models import AutoencoderKLWan
+ from ...utils import logging
+ from ...video_processor import VideoProcessor
+ from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class WanDecodeStep(ModularPipelineBlocks):
+     model_name = "wan"
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return [
+             ComponentSpec("vae", AutoencoderKLWan),
+             ComponentSpec(
+                 "video_processor",
+                 VideoProcessor,
+                 config=FrozenDict({"vae_scale_factor": 8}),
+                 default_creation_method="from_config",
+             ),
+         ]
+
+     @property
+     def description(self) -> str:
+         return "Step that decodes the denoised latents into videos"
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("output_type", default="pil"),
+         ]
+
+     @property
+     def intermediate_inputs(self) -> List[InputParam]:
+         return [
+             InputParam(
+                 "latents",
+                 required=True,
+                 type_hint=torch.Tensor,
+                 description="The denoised latents from the denoising step",
+             )
+         ]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "videos",
+                 type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
+                 description="The generated videos; can be lists of PIL.Image.Image frames, torch.Tensor, or numpy arrays",
+             )
+         ]
+
+     @torch.no_grad()
+     def __call__(self, components, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+         vae_dtype = components.vae.dtype
+
+         if block_state.output_type != "latent":
+             latents = block_state.latents
+             latents_mean = (
+                 torch.tensor(components.vae.config.latents_mean)
+                 .view(1, components.vae.config.z_dim, 1, 1, 1)
+                 .to(latents.device, latents.dtype)
+             )
+             latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+                 1, components.vae.config.z_dim, 1, 1, 1
+             ).to(latents.device, latents.dtype)
+             latents = latents / latents_std + latents_mean
+             latents = latents.to(vae_dtype)
+             block_state.videos = components.vae.decode(latents, return_dict=False)[0]
+         else:
+             block_state.videos = block_state.latents
+
+         block_state.videos = components.video_processor.postprocess_video(
+             block_state.videos, output_type=block_state.output_type
+         )
+
+         self.set_block_state(state, block_state)
+
+         return components, state
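
A closing note on the de-normalization in `WanDecodeStep.__call__`: `latents_std` is constructed as a reciprocal (`1.0 / torch.tensor(...)`), so the subsequent `latents / latents_std + latents_mean` multiplies by the per-channel std before adding the mean, inverting the normalization applied when the video was encoded. A self-contained sketch with made-up statistics (the real values come from `vae.config.latents_mean` and `vae.config.latents_std`):

    import torch

    # Hypothetical per-channel stats standing in for the vae.config values.
    mean = torch.tensor([0.1, -0.2]).view(1, 2, 1, 1, 1)
    std = torch.tensor([2.0, 0.5]).view(1, 2, 1, 1, 1)

    z = torch.randn(1, 2, 3, 4, 4)  # normalized latents, shape (B, C, T, H, W)
    inv_std = 1.0 / std             # mirrors `latents_std` in the code above
    restored = z / inv_std + mean   # equivalent to z * std + mean
    assert torch.allclose(restored, z * std + mean)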