diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/modular_pipelines/wan/denoise.py (new file)

@@ -0,0 +1,261 @@

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import WanTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanLoopDenoiser(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", WanTransformer3DModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("attention_kwargs"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="guider_input_fields",
                description=(
                    "All conditional model inputs that need to be prepared with guider. "
                    "It should contain prompt_embeds/negative_prompt_embeds. "
                    "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
                ),
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
        # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
        guider_input_fields = {
            "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
        }
        transformer_dtype = components.transformer.dtype

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Prepare mini-batches according to guidance method and `guider_input_fields`.
        # Each guider_state_batch will have .prompt_embeds.
        # e.g. for CFG, we prepare two batches: one for uncond, one for cond
        # for the first batch, guider_state_batch.prompt_embeds corresponds to block_state.prompt_embeds
        # for the second batch, guider_state_batch.prompt_embeds corresponds to block_state.negative_prompt_embeds
        guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)

        # run the denoiser for each guidance batch
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            cond_kwargs = guider_state_batch.as_dict()
            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
            prompt_embeds = cond_kwargs.pop("prompt_embeds")

            # Predict the noise residual and store it in guider_state_batch so that
            # we can apply guidance across all batches
            guider_state_batch.noise_pred = components.transformer(
                hidden_states=block_state.latents.to(transformer_dtype),
                timestep=t.flatten(),
                encoder_hidden_states=prompt_embeds,
                attention_kwargs=block_state.attention_kwargs,
                return_dict=False,
            )[0]
            components.guider.cleanup_models(components.transformer)

        # Perform guidance
        block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)

        return components, block_state


class WanLoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return []

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        # Perform scheduler step using the predicted output
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred.float(),
            t,
            block_state.latents.float(),
            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != latents_dtype:
            block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with the `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("scheduler", UniPCMultistepScheduler),
            ComponentSpec("transformer", WanTransformer3DModel),
        ]

    @property
    def loop_intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        self.set_block_state(state, block_state)

        return components, state


class WanDenoiseStep(WanDenoiseLoopWrapper):
    block_classes = [
        WanLoopDenoiser,
        WanLoopAfterDenoiser,
    ]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `WanDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `WanLoopDenoiser`\n"
            " - `WanLoopAfterDenoiser`\n"
            "This block supports text2vid tasks."
        )
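The heavy lifting in `WanLoopDenoiser` is delegated to the guider: `prepare_inputs` splits the block state into one mini-batch per condition, and the final `components.guider(guider_state)` call combines the per-batch noise predictions. As a minimal sketch of what that combination step computes for plain classifier-free guidance at the default `guidance_scale=5.0` (the names below are illustrative, not the guider's actual attributes):

```python
import torch


def cfg_combine(
    noise_pred_cond: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float = 5.0,
) -> torch.Tensor:
    # Standard classifier-free guidance: start from the unconditional
    # prediction and push it toward the conditional one by `guidance_scale`.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


# Toy latent-shaped tensors standing in for the two denoiser passes.
cond = torch.randn(1, 16, 4, 8, 8)
uncond = torch.randn(1, 16, 4, 8, 8)
guided = cfg_combine(cond, uncond)  # same shape as the inputs
```

At `guidance_scale=1.0` the expression reduces to the conditional prediction alone, which is why guidance is effectively disabled there.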
diffusers/modular_pipelines/wan/encoders.py (new file)

@@ -0,0 +1,242 @@

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline


if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text


class WanTextEncoderStep(ModularPipelineBlocks):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("negative_prompt"),
            InputParam("attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="negative text embeddings used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and (
            not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
        ):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        max_sequence_length: int,
        device: torch.device,
    ):
        dtype = components.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
        )

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: str,
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional embeddings or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of text tokens to be used for the generation process.
        """
        device = device or components._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)

        if prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
                components, negative_prompt, max_sequence_length, device
            )

        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        # Encode input prompt
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
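`_get_t5_prompt_embeds` above trims each encoded prompt to its true token length (taken from the attention mask) and then zero-pads it back to `max_sequence_length`, so the pad positions carry exact zeros rather than whatever the encoder produced for pad tokens. A self-contained sketch of that trim-and-pad step, using dummy tensors in place of real UMT5 outputs (all shapes here are made up for illustration):

```python
import torch

max_sequence_length = 6

# Batch of 2 fake "embeddings" (seq_len=6, dim=4) with true lengths 3 and 5,
# as a tokenizer attention mask would report them.
prompt_embeds = torch.randn(2, max_sequence_length, 4)
mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]])
seq_lens = mask.gt(0).sum(dim=1).long()

# Keep only the real tokens, then zero-pad back to the fixed length,
# mirroring the zip/stack/cat logic in `_get_t5_prompt_embeds`.
trimmed = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed],
    dim=0,
)

assert padded.shape == prompt_embeds.shape
assert torch.all(padded[0, 3:] == 0)  # pad positions are now exactly zero
```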
diffusers/modular_pipelines/wan/modular_blocks.py (new file)

@@ -0,0 +1,144 @@

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    WanInputStep,
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents"]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
        )


# before_denoise: all tasks (text2vid,)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        WanBeforeDenoiseStep,
    ]
    block_names = ["text2vid"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            + "This is an auto pipeline block that works for text2vid.\n"
            + " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
        )


# denoise: text2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        WanDenoiseStep,
    ]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. "
            "This is an auto pipeline block that works for text2vid tasks.\n"
            " - `WanDenoiseStep` (denoise) for text2vid tasks."
        )


# decode: all tasks (text2vid,)
class WanAutoDecodeStep(AutoPipelineBlocks):
    block_classes = [WanDecodeStep]
    block_names = ["non-inpaint"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return "Decode step that decodes the denoised latents into video outputs.\n - `WanDecodeStep`"


# text2vid
class WanAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoBeforeDenoiseStep,
        WanAutoDenoiseStep,
        WanAutoDecodeStep,
    ]
    block_names = [
        "text_encoder",
        "before_denoise",
        "denoise",
        "decoder",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-video using Wan.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`"
        )


TEXT2VIDEO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", WanDenoiseStep),
        ("decode", WanDecodeStep),
    ]
)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("before_denoise", WanAutoBeforeDenoiseStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanAutoDecodeStep),
    ]
)


ALL_BLOCKS = {
    "text2video": TEXT2VIDEO_BLOCKS,
    "auto": AUTO_BLOCKS,
}
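For orientation, here is a rough sketch of how these presets might be assembled into a runnable pipeline. It assumes the modular-pipelines API introduced in this release (`from_blocks_dict`, `init_pipeline`, `load_default_components`, and `output=` selection) and uses a placeholder repo id, so treat the exact names as unverified rather than a confirmed recipe:

```python
# Sketch only: the API names below (`from_blocks_dict`, `init_pipeline`,
# `load_default_components`) follow the new modular-pipelines interface but
# are not verified against 0.35.1; the repo id is a placeholder.
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.wan.modular_blocks import AUTO_BLOCKS

# Build the block graph from the preset dict defined above.
blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)

# Bind the blocks to concrete components from a model repo.
pipeline = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # placeholder
pipeline.load_default_components(torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# Per `WanAutoBlocks.description`, text-to-video only needs `prompt`.
video = pipeline(prompt="a cat surfing a wave at sunset", output="videos")
```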