diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/modular_pipelines/node_utils.py (new file)
@@ -0,0 +1,665 @@
```python
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import PIL
import torch

from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam


logger = logging.getLogger(__name__)

# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
    "prompt": InputParam(
        "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
    ),
    "prompt_2": InputParam(
        "prompt_2",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
    ),
    "negative_prompt": InputParam(
        "negative_prompt",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts not to guide the image generation",
    ),
    "negative_prompt_2": InputParam(
        "negative_prompt_2",
        type_hint=Union[str, List[str]],
        description="The negative prompt or prompts for text_encoder_2",
    ),
    "cross_attention_kwargs": InputParam(
        "cross_attention_kwargs",
        type_hint=Optional[dict],
        description="Kwargs dictionary passed to the AttentionProcessor",
    ),
    "clip_skip": InputParam(
        "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
    ),
    "image": InputParam(
        "image",
        type_hint=PipelineImageInput,
        required=True,
        description="The image(s) to modify for img2img or inpainting",
    ),
    "mask_image": InputParam(
        "mask_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Mask image for inpainting, white pixels will be repainted",
    ),
    "generator": InputParam(
        "generator",
        type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
        description="Generator(s) for deterministic generation",
    ),
    "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
    "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
    "num_images_per_prompt": InputParam(
        "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
    ),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
    ),
    "timesteps": InputParam(
        "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
    ),
    "sigmas": InputParam(
        "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
    ),
    "denoising_end": InputParam(
        "denoising_end",
        type_hint=Optional[float],
        description="Fraction of denoising process to complete before termination",
    ),
    # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
    "strength": InputParam(
        "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
    ),
    "denoising_start": InputParam(
        "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
    ),
    "latents": InputParam(
        "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
    ),
    "padding_mask_crop": InputParam(
        "padding_mask_crop",
        type_hint=Optional[Tuple[int, int]],
        description="Size of margin in crop for image and mask",
    ),
    "original_size": InputParam(
        "original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Original size of the image for SDXL's micro-conditioning",
    ),
    "target_size": InputParam(
        "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
    ),
    "negative_original_size": InputParam(
        "negative_original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on image resolution",
    ),
    "negative_target_size": InputParam(
        "negative_target_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on target resolution",
    ),
    "crops_coords_top_left": InputParam(
        "crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Top-left coordinates for SDXL's micro-conditioning",
    ),
    "negative_crops_coords_top_left": InputParam(
        "negative_crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Negative conditioning crop coordinates",
    ),
    "aesthetic_score": InputParam(
        "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
    ),
    "negative_aesthetic_score": InputParam(
        "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
    ),
    "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
    "output_type": InputParam(
        "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
    ),
    "ip_adapter_image": InputParam(
        "ip_adapter_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Image(s) to be used as IP adapter",
    ),
    "control_image": InputParam(
        "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
    ),
    "control_guidance_start": InputParam(
        "control_guidance_start",
        type_hint=Union[float, List[float]],
        default=0.0,
        description="When ControlNet starts applying",
    ),
    "control_guidance_end": InputParam(
        "control_guidance_end",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="When ControlNet stops applying",
    ),
    "controlnet_conditioning_scale": InputParam(
        "controlnet_conditioning_scale",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="Scale factor for ControlNet outputs",
    ),
    "guess_mode": InputParam(
        "guess_mode",
        type_hint=bool,
        default=False,
        description="Enables ControlNet encoder to recognize input without prompts",
    ),
    "control_mode": InputParam(
        "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
    ),
}

SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
    "prompt_embeds": InputParam(
        "prompt_embeds",
        type_hint=torch.Tensor,
        required=True,
        description="Text embeddings used to guide image generation",
    ),
    "negative_prompt_embeds": InputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": InputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": InputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
    "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "preprocess_kwargs": InputParam(
        "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
    ),
    "latents": InputParam(
        "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
    ),
    "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
    ),
    "latent_timestep": InputParam(
        "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
    ),
    "image_latents": InputParam(
        "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
    ),
    "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
    "masked_image_latents": InputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "add_time_ids": InputParam(
        "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
    ),
    "negative_add_time_ids": InputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "ip_adapter_embeds": InputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": InputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": InputParam(
        "images",
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
        required=True,
        description="Generated images",
    ),
}

SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}


DEFAULT_PARAM_MAPS = {
    "prompt": {
        "label": "Prompt",
        "type": "string",
        "default": "a bear sitting in a chair drinking a milkshake",
        "display": "textarea",
    },
    "negative_prompt": {
        "label": "Negative Prompt",
        "type": "string",
        "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        "display": "textarea",
    },
    "num_inference_steps": {
        "label": "Steps",
        "type": "int",
        "default": 25,
        "min": 1,
        "max": 1000,
    },
    "seed": {
        "label": "Seed",
        "type": "int",
        "default": 0,
        "min": 0,
        "display": "random",
    },
    "width": {
        "label": "Width",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "height": {
        "label": "Height",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "images": {
        "label": "Images",
        "type": "image",
        "display": "output",
    },
    "image": {
        "label": "Image",
        "type": "image",
        "display": "input",
    },
}

DEFAULT_TYPE_MAPS = {
    "int": {
        "type": "int",
        "default": 0,
        "min": 0,
    },
    "float": {
        "type": "float",
        "default": 0.0,
        "min": 0.0,
    },
    "str": {
        "type": "string",
        "default": "",
    },
    "bool": {
        "type": "boolean",
        "default": False,
    },
    "image": {
        "type": "image",
    },
}

DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
    "text_encoders": ["text_encoder", "tokenizer"],
    "ip_adapter_embeds": ["ip_adapter_embeds"],
    "prompt_embeddings": ["prompt_embeds"],
}


def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
    """
    Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
    "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
    """
    if name is None:
        return None
    for group_name, group_keys in group_params_keys.items():
        for group_key in group_keys:
            if group_key in name:
                return group_name
    return None


class ModularNode(ConfigMixin):
    """
    A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
    around a ModularPipelineBlocks object.

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    config_name = "node_config.json"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        blocks = ModularPipelineBlocks.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
        self.blocks = blocks

        if label is None:
            label = self.blocks.__class__.__name__
        # blocks param name -> mellon param name
        self.name_mapping = {}

        input_params = {}
        # pass or create a default param dict for each input
        # e.g. for prompt,
        # prompt = {
        #     "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
        #     "label": "Prompt",
        #     "type": "string",
        #     "default": "a bear sitting in a chair drinking a milkshake",
        #     "display": "textarea"}
        # if type is not specified, it'll be a "custom" param of its own type
        # e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
        # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
        # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
        inputs = self.blocks.inputs + self.blocks.intermediate_inputs
        for inp in inputs:
            param = kwargs.pop(inp.name, None)
            if param:
                # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
                input_params[inp.name] = param
                mellon_name = param.pop("name", inp.name)
                if mellon_name != inp.name:
                    self.name_mapping[inp.name] = mellon_name
                continue

            if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
                continue

            if inp.name in DEFAULT_PARAM_MAPS:
                # first check if it's in the default param map, if so, directly use that
                param = DEFAULT_PARAM_MAPS[inp.name].copy()
            elif get_group_name(inp.name):
                param = get_group_name(inp.name)
                if inp.name not in self.name_mapping:
                    self.name_mapping[inp.name] = param
            else:
                # if not, check if it's in the SDXL input schema, if so,
                # 1. use the type hint to determine the type
                # 2. use the default param dict for the type e.g. if "steps" is an "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
                if inp.type_hint is not None:
                    type_str = str(inp.type_hint).lower()
                else:
                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
                    if type_key in type_str:
                        param = type_param.copy()
                        param["label"] = inp.name
                        param["display"] = "input"
                        break
                else:
                    param = inp.name
            # add the param dict to the inp_params dict
            input_params[inp.name] = param

        component_params = {}
        for comp in self.blocks.expected_components:
            param = kwargs.pop(comp.name, None)
            if param:
                component_params[comp.name] = param
                mellon_name = param.pop("name", comp.name)
                if mellon_name != comp.name:
                    self.name_mapping[comp.name] = mellon_name
                continue

            to_exclude = False
            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
                if exclude_key in comp.name:
                    to_exclude = True
                    break
            if to_exclude:
                continue

            if get_group_name(comp.name):
                param = get_group_name(comp.name)
                if comp.name not in self.name_mapping:
                    self.name_mapping[comp.name] = param
            elif comp.name in DEFAULT_MODEL_KEYS:
                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
            else:
                param = comp.name
            # add the param dict to the model_params dict
            component_params[comp.name] = param

        output_params = {}
        if isinstance(self.blocks, SequentialPipelineBlocks):
            last_block_name = list(self.blocks.sub_blocks.keys())[-1]
            outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
        else:
            outputs = self.blocks.intermediate_outputs

        for out in outputs:
            param = kwargs.pop(out.name, None)
            if param:
                output_params[out.name] = param
                mellon_name = param.pop("name", out.name)
                if mellon_name != out.name:
                    self.name_mapping[out.name] = mellon_name
                continue

            if out.name in DEFAULT_PARAM_MAPS:
                param = DEFAULT_PARAM_MAPS[out.name].copy()
                param["display"] = "output"
            else:
                group_name = get_group_name(out.name)
                if group_name:
                    param = group_name
                    if out.name not in self.name_mapping:
                        self.name_mapping[out.name] = param
                else:
                    param = out.name
            # add the param dict to the outputs dict
            output_params[out.name] = param

        if len(kwargs) > 0:
            logger.warning(f"Unused kwargs: {kwargs}")

        register_dict = {
            "category": category,
            "label": label,
            "input_params": input_params,
            "component_params": component_params,
            "output_params": output_params,
            "name_mapping": self.name_mapping,
        }
        self.register_to_config(**register_dict)

    def setup(self, components_manager, collection=None):
        self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
        self._components_manager = components_manager

    @property
    def mellon_config(self):
        return self._convert_to_mellon_config()

    def _convert_to_mellon_config(self):
        node = {}
        node["label"] = self.config.label
        node["category"] = self.config.category

        node_param = {}
        for inp_name, inp_param in self.config.input_params.items():
            if inp_name in self.name_mapping:
                mellon_name = self.name_mapping[inp_name]
            else:
                mellon_name = inp_name
            if isinstance(inp_param, str):
                param = {
                    "label": inp_param,
                    "type": inp_param,
                    "display": "input",
                }
            else:
                param = inp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")

        for comp_name, comp_param in self.config.component_params.items():
            if comp_name in self.name_mapping:
                mellon_name = self.name_mapping[comp_name]
            else:
                mellon_name = comp_name
            if isinstance(comp_param, str):
                param = {
                    "label": comp_param,
                    "type": comp_param,
                    "display": "input",
                }
            else:
                param = comp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")

        for out_name, out_param in self.config.output_params.items():
            if out_name in self.name_mapping:
                mellon_name = self.name_mapping[out_name]
            else:
                mellon_name = out_name
            if isinstance(out_param, str):
                param = {
                    "label": out_param,
                    "type": out_param,
                    "display": "output",
                }
            else:
                param = out_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
        node["params"] = node_param
        return node

    def save_mellon_config(self, file_path):
        """
        Save the Mellon configuration to a JSON file.

        Args:
            file_path (str or Path): Path where the JSON file will be saved

        Returns:
            Path: Path to the saved config file
        """
        file_path = Path(file_path)

        # Create directory if it doesn't exist
        os.makedirs(file_path.parent, exist_ok=True)

        # Create a combined dictionary with module definition and name mapping
        config = {"module": self.mellon_config, "name_mapping": self.name_mapping}

        # Save the config to file
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info(f"Mellon config and name mapping saved to {file_path}")

        return file_path

    @classmethod
    def load_mellon_config(cls, file_path):
        """
        Load a Mellon configuration from a JSON file.

        Args:
            file_path (str or Path): Path to the JSON file containing Mellon config

        Returns:
            dict: The loaded combined configuration containing 'module' and 'name_mapping'
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        logger.info(f"Mellon config loaded from {file_path}")

        return config

    def process_inputs(self, **kwargs):
        params_components = {}
        for comp_name, comp_param in self.config.component_params.items():
            logger.debug(f"component: {comp_name}")
            mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
            if mellon_comp_name in kwargs:
                if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
                    comp = kwargs[mellon_comp_name].pop(comp_name)
                else:
                    comp = kwargs.pop(mellon_comp_name)
                if comp:
                    params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

        params_run = {}
        for inp_name, inp_param in self.config.input_params.items():
            logger.debug(f"input: {inp_name}")
            mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
            if mellon_inp_name in kwargs:
                if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
                    inp = kwargs[mellon_inp_name].pop(inp_name)
                else:
                    inp = kwargs.pop(mellon_inp_name)
                if inp is not None:
                    params_run[inp_name] = inp

        return_output_names = list(self.config.output_params.keys())

        return params_components, params_run, return_output_names

    def execute(self, **kwargs):
        params_components, params_run, return_output_names = self.process_inputs(**kwargs)

        self.pipeline.update_components(**params_components)
        output = self.pipeline(**params_run, output=return_output_names)
        return output
```