diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux_prior_redux.py (new file)
@@ -0,0 +1,492 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Union
+
+import torch
+from PIL import Image
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    SiglipImageProcessor,
+    SiglipVisionModel,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_flux import ReduxImageEncoder
+from .pipeline_output import FluxPriorReduxPipelineOutput
+
+
+if is_torch_xla_available():
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> device = "cuda"
+        >>> dtype = torch.bfloat16
+
+        >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+        >>> repo_base = "black-forest-labs/FLUX.1-dev"
+        >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+        >>> pipe = FluxPipeline.from_pretrained(
+        ...     repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
+        ... ).to(device)
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
+        ... )
+        >>> pipe_prior_output = pipe_prior_redux(image)
+        >>> images = pipe(
+        ...     guidance_scale=2.5,
+        ...     num_inference_steps=50,
+        ...     generator=torch.Generator("cpu").manual_seed(0),
+        ...     **pipe_prior_output,
+        ... ).images
+        >>> images[0].save("flux-redux.png")
+        ```
+"""
+
+
+class FluxPriorReduxPipeline(DiffusionPipeline):
+    r"""
+    The Flux Redux pipeline for image-to-image generation.
+
+    Reference: https://blackforestlabs.ai/flux-1-tools/
+
+    Args:
+        image_encoder ([`SiglipVisionModel`]):
+            SIGLIP vision model to encode the input image.
+        feature_extractor ([`SiglipImageProcessor`]):
+            Image processor for preprocessing images for the SIGLIP model.
+        image_embedder ([`ReduxImageEncoder`]):
+            Redux image encoder to process the SIGLIP embeddings.
+        text_encoder ([`CLIPTextModel`], *optional*):
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`T5EncoderModel`], *optional*):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`CLIPTokenizer`, *optional*):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`T5TokenizerFast`, *optional*):
+            Second Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+    """
+
+    model_cpu_offload_seq = "image_encoder->image_embedder"
+    _optional_components = [
+        "text_encoder",
+        "tokenizer",
+        "text_encoder_2",
+        "tokenizer_2",
+    ]
+    _callback_tensor_inputs = []
+
+    def __init__(
+        self,
+        image_encoder: SiglipVisionModel,
+        feature_extractor: SiglipImageProcessor,
+        image_embedder: ReduxImageEncoder,
+        text_encoder: CLIPTextModel = None,
+        tokenizer: CLIPTokenizer = None,
+        text_encoder_2: T5EncoderModel = None,
+        tokenizer_2: T5TokenizerFast = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            image_embedder=image_embedder,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+        )
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+
+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and (
+            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+        ):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )
+
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+        image = self.feature_extractor.preprocess(
+            images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
+        )
+        image = image.to(device=device, dtype=dtype)
+
+        image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
+        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+        return image_enc_hidden_states
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+        text_inputs = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+        dtype = self.text_encoder_2.dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+    ):
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_max_length,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer_max_length} tokens: {removed_text}"
+            )
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+        # Use pooled output of CLIPTextModel
+        prompt_embeds = prompt_embeds.pooler_output
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in all text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # We only use the pooled prompt output from the CLIPTextModel
+            pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+            )
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt_2,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+        return prompt_embeds, pooled_prompt_embeds, text_ids
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+                make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
+                are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
+            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is a list with the generated images.
+        """
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
+
+        # 2. Define call parameters
+        if image is not None and isinstance(image, Image.Image):
+            batch_size = 1
+        elif image is not None and isinstance(image, list):
+            batch_size = len(image)
+        else:
+            batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
+        device = self._execution_device
+
+        # 3. Prepare image embeddings
+        image_latents = self.encode_image(image, device, 1)
+
+        image_embeds = self.image_embedder(image_latents).image_embeds
+        image_embeds = image_embeds.to(device=device)
+
+        # 3. Prepare (dummy) text embeddings
+        if hasattr(self, "text_encoder") and self.text_encoder is not None:
+            (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                prompt_2=prompt_2,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=1,
+                max_sequence_length=512,
+                lora_scale=None,
+            )
+        else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded to the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input. "
+                )
+            # max_sequence_length is 512, t5 encoder hidden size is 4096
+            prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
+            # pooled_prompt_embeds is 768, clip text encoder hidden size
+            pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
+
+        # scale & concatenate image and text embeddings
+        prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
+
+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (prompt_embeds, pooled_prompt_embeds)
+
+        return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
diffusers/pipelines/flux/pipeline_output.py
@@ -3,6 +3,7 @@ from typing import List, Union
 
 import numpy as np
 import PIL.Image
+import torch
 
 from ...utils import BaseOutput
 
@@ -19,3 +20,18 @@ class FluxPipelineOutput(BaseOutput):
     """
 
     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class FluxPriorReduxPipelineOutput(BaseOutput):
+    """
+    Output class for Flux Prior Redux pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    prompt_embeds: torch.Tensor
+    pooled_prompt_embeds: torch.Tensor
diffusers/pipelines/hunyuan_video/__init__.py (new file)
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *
+    else:
+        from .pipeline_hunyuan_video import HunyuanVideoPipeline
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)