diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
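Highlights visible in the file list above: a new `diffusers/quantizers` package (bitsandbytes-backed 4-/8-bit loading), Flux ControlNet/img2img/inpaint pipelines, the CogView3-Plus transformer and pipeline, and widespread docstring and XLA cleanups across the Stable Diffusion pipelines. As a rough sketch of what the new quantizer module enables (class and argument names assumed from the 0.31.0 additions listed above, not verified against this wheel):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Assumed API: BitsAndBytesConfig and the `quantization_config` argument are part of the
# new diffusers/quantizers package; requires the optional bitsandbytes dependency.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint with a `transformer` subfolder
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```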
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
@@ -57,9 +65,21 @@ EXAMPLE_DOC_STRING = """


 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-
-
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
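The rescale documented above keeps the classifier-free-guidance output from inflating the per-sample standard deviation. A toy sketch of the computation, built on the two `std` lines shown in the hunk (the tensors here are random stand-ins):

```python
import torch

noise_pred_text = torch.randn(2, 4, 64, 64)                                    # text-conditioned prediction
noise_cfg = 7.5 * noise_pred_text - 6.5 * torch.randn_like(noise_pred_text)    # over-scaled CFG output
guidance_rescale = 0.7

std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

rescaled = noise_cfg * (std_text / std_cfg)                                    # match the text prediction's std
noise_cfg = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg   # blend back by guidance_rescale
```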
@@ -78,7 +98,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -137,7 +157,7 @@ class StableDiffusionPipeline(
     IPAdapterMixin,
     FromSingleFileMixin,
 ):
-
+    """
     Pipeline for text-to-image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
@@ -1036,6 +1056,9 @@ class StableDiffusionPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
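Functionally, the new `xm.mark_step()` call flushes the XLA graph once per denoising step. A usage sketch (checkpoint id illustrative; requires the optional torch_xla install and an XLA device such as a TPU VM):

```python
import torch
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
).to(xm.xla_device())

# In 0.31.0 the pipeline calls xm.mark_step() inside its denoising loop, so each
# scheduler step is compiled and dispatched instead of accumulating one huge graph.
image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]
```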
@@ -1049,7 +1072,6 @@ class StableDiffusionPipeline(
             do_denormalize = [True] * image.shape[0]
         else:
             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         # Offload all models
@@ -119,7 +119,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -60,7 +60,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
 def preprocess(image):
     warnings.warn(
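A quick sketch of what the copied `retrieve_latents` helper (from the hunk above) does with a VAE encode output; the VAE id is illustrative and the random tensor stands in for a preprocessed image:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
image = torch.randn(1, 3, 512, 512)          # stand-in for a preprocessed RGB image tensor

posterior = vae.encode(image)                 # AutoencoderKLOutput carrying a .latent_dist
generator = torch.Generator().manual_seed(0)

latents = retrieve_latents(posterior, generator=generator)         # stochastic sample from the posterior
latents_mode = retrieve_latents(posterior, sample_mode="argmax")    # deterministic mode of the posterior
```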
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")

-    def _encode_prompt(
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            **kwargs,
+        )
+
+        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
+        return prompt_embeds, pooled_prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+    ):
         r"""
         Encodes the prompt into text encoder hidden states.

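In practice the shim keeps old callers working while the new `encode_prompt` changes the return contract; a migration sketch (checkpoint id illustrative):

```python
import torch
from diffusers import StableDiffusionLatentUpscalePipeline

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float32
)

# Deprecated path: warns and returns CFG-concatenated (prompt_embeds, pooled_prompt_embeds).
embeds, pooled = upscaler._encode_prompt("a castle", "cpu", True)

# 0.31.0 path: positive and negative embeddings come back separately as a 4-tuple.
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
    upscaler.encode_prompt("a castle", "cpu", True)
)
```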
@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             negative_prompt (`str` or `List[str]`):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
         """
-
-
-
-            prompt
-
-
-            truncation=True,
-            return_length=True,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-
-        text_encoder_out = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=True,
-        )
-        text_embeddings = text_encoder_out.hidden_states[-1]
-        text_pooler_out = text_encoder_out.pooler_output
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]

-
-
-
+        if prompt_embeds is None or pooled_prompt_embeds is None:
+            text_inputs = self.tokenizer(
+                prompt,
                 padding="max_length",
-                max_length=
+                max_length=self.tokenizer.model_max_length,
                 truncation=True,
                 return_length=True,
                 return_tensors="pt",
             )
+            text_input_ids = text_inputs.input_ids

-
-
+            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = self.tokenizer.batch_decode(
+                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                )
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            text_encoder_out = self.text_encoder(
+                text_input_ids.to(device),
                 output_hidden_states=True,
             )
+            prompt_embeds = text_encoder_out.hidden_states[-1]
+            pooled_prompt_embeds = text_encoder_out.pooler_output

-
-
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
+                uncond_tokens: List[str]
+                if negative_prompt is None:
+                    uncond_tokens = [""] * batch_size
+                elif type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif isinstance(negative_prompt, str):
+                    uncond_tokens = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                else:
+                    uncond_tokens = negative_prompt
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_length=True,
+                    return_tensors="pt",
+                )
+
+                uncond_encoder_out = self.text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )

-
-
-            # to avoid doing two forward passes
-            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-            text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
+                negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
+                negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output

-        return
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image

-    def check_inputs(
-
+    def check_inputs(
+        self,
+        prompt,
+        image,
+        callback_steps,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
         if (
             not isinstance(image, torch.Tensor)
+            and not isinstance(image, np.ndarray)
             and not isinstance(image, PIL.Image.Image)
             and not isinstance(image, list)
         ):
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix

         # verify batch size of prompt and image are same if image is a list or tensor
         if isinstance(image, (list, torch.Tensor)):
-            if
-
+            if prompt is not None:
+                if isinstance(prompt, str):
+                    batch_size = 1
+                else:
+                    batch_size = len(prompt)
             else:
-                batch_size =
+                batch_size = prompt_embeds.shape[0]
+
             if isinstance(image, list):
                 image_batch_size = len(image)
             else:
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
     @torch.no_grad()
     def __call__(
         self,
-        prompt: Union[str, List[str]],
+        prompt: Union[str, List[str]] = None,
         image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
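Taken together, the latent upscaler can now be driven entirely from precomputed embeddings. An end-to-end sketch (checkpoint id illustrative; the random tensor stands in for base-model latents produced with `output_type="latent"`):

```python
import torch
from diffusers import StableDiffusionLatentUpscalePipeline

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

# Stand-in for latents from a base Stable Diffusion pipeline run with output_type="latent".
low_res_latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")

prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
    upscaler.encode_prompt("a detailed photo of a castle", "cuda", True)
)

upscaled = upscaler(
    image=low_res_latents,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    num_inference_steps=20,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
```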
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         """

         # 1. Check inputs
-        self.check_inputs(
+        self.check_inputs(
+            prompt,
+            image,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )

         # 2. Define call parameters
-
+        if prompt is not None:
+            batch_size = 1 if isinstance(prompt, str) else len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             prompt = [""] * batch_size

         # 3. Encode input prompt
-
-
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            device,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
         )

+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+
         # 4. Preprocess image
         image = self.image_processor.preprocess(image)
-        image = image.to(dtype=
+        image = image.to(dtype=prompt_embeds.dtype, device=device)
         if image.shape[1] == 3:
             # encode image if not in latent-space yet
-            image = self.vae.encode(image)
+            image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor

         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
         inv_noise_level = (noise_level**2 + 1) ** (-0.5)

         image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
-        image_cond = image_cond.to(
+        image_cond = image_cond.to(prompt_embeds.dtype)

         noise_level_embed = torch.cat(
             [
-                torch.ones(
-                torch.zeros(
+                torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
+                torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
             ],
             dim=1,
         )

-        timestep_condition = torch.cat([noise_level_embed,
+        timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)

         # 6. Prepare latent variables
         height, width = image.shape[2:]
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             num_channels_latents,
             height * 2,  # 2x upscale
             width * 2,
-
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
             noise_pred = self.unet(
                 scaled_model_input,
                 timestep,
-                encoder_hidden_states=
+                encoder_hidden_states=prompt_embeds,
                 timestep_cond=timestep_condition,
             ).sample

@@ -77,7 +77,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -203,6 +203,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             if hasattr(self, "transformer") and self.transformer is not None
             else 128
         )
+        self.patch_size = (
+            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+        )

     def _get_t5_prompt_embeds(
         self,
@@ -525,8 +528,14 @@
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if
-
+        if (
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
+                f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+            )

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -13,7 +13,7 @@
 # limitations under the License.

 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import PIL.Image
 import torch
@@ -25,7 +25,7 @@ from transformers import (
 )

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -98,7 +98,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -149,7 +149,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -680,6 +680,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
     def guidance_scale(self):
         return self._guidance_scale

+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -723,6 +727,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -835,6 +844,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

         # 2. Define call parameters
@@ -847,6 +857,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):

         device = self._execution_device

+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -868,6 +882,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )

         if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]

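These hunks wire `joint_attention_kwargs` through the SD3 img2img call path so a LoRA scale reaches the attention processors. A usage sketch (checkpoint id illustrative; the commented LoRA repo is hypothetical):

```python
import torch
from PIL import Image
from diffusers import StableDiffusion3Img2ImgPipeline

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
# pipe.load_lora_weights("path/to/sd3-lora")  # hypothetical LoRA; "scale" below is forwarded to its layers

init_image = Image.new("RGB", (1024, 1024), color="gray")  # stand-in for a real input image
image = pipe(
    prompt="a watercolor painting of a lighthouse",
    image=init_image,
    strength=0.6,
    joint_attention_kwargs={"scale": 0.8},  # passed through to the attention processors / LoRA scale
).images[0]
```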