diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux.py

@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -86,7 +93,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -137,7 +144,13 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+class FluxPipeline(
+    DiffusionPipeline,
+    FluxLoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+    FluxIPAdapterMixin,
+):
     r"""
     The Flux pipeline for text-to-image generation.
 
@@ -164,8 +177,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
@@ -177,6 +190,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         text_encoder_2: T5EncoderModel,
         tokenizer_2: T5TokenizerFast,
         transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()
 
@@ -188,15 +203,19 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             tokenizer_2=tokenizer_2,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
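The scale-factor change above is the core of this hunk: 0.30.3 over-counted the VAE's downsampling by one factor of two and compensated elsewhere. A minimal sketch of the corrected arithmetic, assuming a standard Flux-style VAE whose `block_out_channels` has length 4 (the concrete channel values below are illustrative only):

```python
# Assumed: a Flux-style VAE config with four entries in block_out_channels.
block_out_channels = [128, 256, 512, 512]  # illustrative; only the length matters

old_scale = 2 ** len(block_out_channels)        # 16 -- one halving too many
new_scale = 2 ** (len(block_out_channels) - 1)  # 8  -- the VAE's actual 8x compression

# Packed latents are built from 2x2 patches, so the image processor now
# aligns inputs to new_scale * patch_size instead:
alignment = new_scale * 2  # 16
height = width = 1024
assert height % alignment == 0 and width % alignment == 0
```

For standard checkpoints the effective alignment is unchanged (both versions round to multiples of 16); the factor is just attributed correctly between VAE compression and patch packing.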
@@ -212,6 +231,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
         text_inputs = self.tokenizer_2(
             prompt,
             padding="max_length",
@@ -255,6 +277,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
         text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
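Both prompt encoders now expand multi-vector textual-inversion placeholders before tokenizing. A simplified, self-contained sketch of what `maybe_convert_prompt` does; the helper below is hypothetical and stands in for the real method on `TextualInversionLoaderMixin`, which consults the tokenizer's added tokens:

```python
# Hypothetical re-implementation for illustration only.
def expand_multi_vector(prompt: str, added_tokens: set) -> str:
    out = []
    for word in prompt.split():
        out.append(word)
        i = 1
        # multi-vector embeddings register extra tokens named <tok>_1, <tok>_2, ...
        while f"{word}_{i}" in added_tokens:
            out.append(f"{word}_{i}")
            i += 1
    return " ".join(out)

print(expand_multi_vector("a photo of <cat-toy>", {"<cat-toy>_1", "<cat-toy>_2"}))
# -> "a photo of <cat-toy> <cat-toy>_1 <cat-toy>_2"
```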
@@ -331,10 +356,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
@@ -364,24 +385,71 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
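A shape-only sketch of the duplication done in `encode_image` above: `repeat_interleave` repeats each row in place, so generated images stay grouped with their source prompt. The embedding width 768 is an assumption for illustration:

```python
import torch

image_embeds = torch.randn(2, 768)  # hypothetical CLIP image embeds for 2 prompts
num_images_per_prompt = 3
out = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
assert out.shape == (6, 768)
# row order: [p0, p0, p0, p1, p1, p1] -- per-prompt grouping preserved
```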
@@ -409,25 +477,47 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
-        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         latent_image_ids = latent_image_ids.reshape(
-            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
         return latent_image_ids.to(device=device, dtype=dtype)
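A standalone sketch of what the rewritten `_prepare_latent_image_ids` produces: a flat `(height * width, 3)` grid of RoPE position ids with the batch dimension dropped (callers now pass the patch-grid size directly; the 4x6 grid below is arbitrary):

```python
import torch

height, width = 4, 6  # latent patch grid, i.e. pixel size // (vae_scale_factor * 2)
ids = torch.zeros(height, width, 3)
ids[..., 1] += torch.arange(height)[:, None]  # channel 1: row index
ids[..., 2] += torch.arange(width)[None, :]   # channel 2: column index
ids = ids.reshape(height * width, 3)
assert ids.shape == (24, 3)
```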
@@ -444,16 +534,47 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def prepare_latents(
         self,
         batch_size,
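The reworked `_unpack_latents` is the exact inverse of the pipeline's `_pack_latents`. A round-trip sketch under assumed shapes (16 latent channels on a 64x64 latent grid), showing that one token is formed per 2x2 patch:

```python
import torch

batch, channels, h, w = 1, 16, 64, 64
latents = torch.randn(batch, channels, h, w)

# pack: (B, C, H, W) -> (B, H/2 * W/2, C*4), one token per 2x2 patch
packed = latents.view(batch, channels, h // 2, 2, w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(batch, (h // 2) * (w // 2), channels * 4)

# unpack: the same permutation in reverse, as in _unpack_latents above
unpacked = packed.view(batch, h // 2, w // 2, channels, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(batch, channels, h, w)
assert torch.equal(latents, unpacked)
```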
@@ -465,13 +586,15 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
 
         shape = (batch_size, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -483,7 +606,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         return latents, latent_image_ids
 
@@ -509,16 +632,25 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
-        guidance_scale: float = 7.0,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
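A hedged usage sketch of the new call signature: negative prompting via `true_cfg_scale` alongside the distilled `guidance_scale`. It assumes the FLUX.1-dev checkpoint and enough memory for CPU offload; prompts and values are illustrative only:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a misty forest at dawn",
    negative_prompt="blurry, low quality",  # only takes effect when true_cfg_scale > 1
    true_cfg_scale=4.0,                     # runs a second, negative forward pass per step
    guidance_scale=3.5,                     # distilled guidance embedding, new default
    num_inference_steps=28,
).images[0]
image.save("flux.png")
```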
@@ -543,10 +675,10 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -568,6 +700,17 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
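Per the docstring above, pre-computed embeddings are passed as one tensor per loaded IP-Adapter. A shape-only sketch; the 768-dim embedding width is an assumption for a CLIP-style image encoder:

```python
import torch

# one list entry per loaded IP-Adapter, each of shape (batch_size, num_images, emb_dim)
ip_adapter_image_embeds = [torch.randn(1, 1, 768)]
```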
@@ -605,8 +748,12 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -628,6 +775,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -642,6 +790,21 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -657,7 +820,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         )
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
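The default schedule, used whenever the caller passes no `sigmas`, is a linear ramp from 1.0 down to `1/num_inference_steps`; a quick sketch of the arithmetic:

```python
import numpy as np

num_inference_steps = 4
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas)  # [1.0, 0.75, 0.5, 0.25]
```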
@@ -670,32 +833,61 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # handle guidance
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
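Hoisting the guidance tensor out of the loop is a small but real saving, since it is constant across steps. A sketch of the construction under assumed values (batch of 2, scale 3.5):

```python
import torch

guidance_scale, batch = 3.5, 2
guidance = torch.full([1], guidance_scale, dtype=torch.float32)
guidance = guidance.expand(batch)  # a view, no copy: tensor([3.5000, 3.5000])
```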
@@ -706,6 +898,22 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                     return_dict=False,
                 )[0]
 
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
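The blend added above is standard classifier-free guidance applied at inference, in contrast to the distilled guidance embedding the model was trained with: the positive and negative predictions are mixed as `neg + s * (pos - neg)`, which reduces to the positive prediction at `s = 1`. A tensor-level sketch with arbitrary shapes:

```python
import torch

true_cfg_scale = 4.0
noise_pred = torch.randn(1, 4096, 64)      # prediction conditioned on the prompt
neg_noise_pred = torch.randn(1, 4096, 64)  # prediction conditioned on the negative prompt
guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
assert guided.shape == noise_pred.shape
```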