diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -33,6 +33,20 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
|
|
33
33
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
34
34
|
|
35
35
|
|
36
|
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
37
|
+
def retrieve_latents(
|
38
|
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
39
|
+
):
|
40
|
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
41
|
+
return encoder_output.latent_dist.sample(generator)
|
42
|
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
43
|
+
return encoder_output.latent_dist.mode()
|
44
|
+
elif hasattr(encoder_output, "latents"):
|
45
|
+
return encoder_output.latents
|
46
|
+
else:
|
47
|
+
raise AttributeError("Could not access latents of provided encoder_output")
|
48
|
+
|
49
|
+
|
36
50
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
|
37
51
|
def preprocess(image):
|
38
52
|
warnings.warn(
|
@@ -105,7 +119,54 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
105
119
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
106
120
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
|
107
121
|
|
108
|
-
def _encode_prompt(
|
122
|
+
def _encode_prompt(
|
123
|
+
self,
|
124
|
+
prompt,
|
125
|
+
device,
|
126
|
+
do_classifier_free_guidance,
|
127
|
+
negative_prompt=None,
|
128
|
+
prompt_embeds: Optional[torch.Tensor] = None,
|
129
|
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
130
|
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
131
|
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
132
|
+
**kwargs,
|
133
|
+
):
|
134
|
+
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
135
|
+
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
136
|
+
|
137
|
+
(
|
138
|
+
prompt_embeds,
|
139
|
+
negative_prompt_embeds,
|
140
|
+
pooled_prompt_embeds,
|
141
|
+
negative_pooled_prompt_embeds,
|
142
|
+
) = self.encode_prompt(
|
143
|
+
prompt=prompt,
|
144
|
+
device=device,
|
145
|
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
146
|
+
negative_prompt=negative_prompt,
|
147
|
+
prompt_embeds=prompt_embeds,
|
148
|
+
negative_prompt_embeds=negative_prompt_embeds,
|
149
|
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
150
|
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
151
|
+
**kwargs,
|
152
|
+
)
|
153
|
+
|
154
|
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
155
|
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
|
156
|
+
|
157
|
+
return prompt_embeds, pooled_prompt_embeds
|
158
|
+
|
159
|
+
def encode_prompt(
|
160
|
+
self,
|
161
|
+
prompt,
|
162
|
+
device,
|
163
|
+
do_classifier_free_guidance,
|
164
|
+
negative_prompt=None,
|
165
|
+
prompt_embeds: Optional[torch.Tensor] = None,
|
166
|
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
167
|
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
168
|
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
169
|
+
):
|
109
170
|
r"""
|
110
171
|
Encodes the prompt into text encoder hidden states.
|
111
172
|
|
@@ -119,81 +180,100 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
119
180
|
negative_prompt (`str` or `List[str]`):
|
120
181
|
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
121
182
|
if `guidance_scale` is less than `1`).
|
183
|
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
184
|
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
185
|
+
provided, text embeddings will be generated from `prompt` input argument.
|
186
|
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
187
|
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
188
|
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
189
|
+
argument.
|
190
|
+
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
191
|
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
192
|
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
193
|
+
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
194
|
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
195
|
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
196
|
+
input argument.
|
122
197
|
"""
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
prompt
|
127
|
-
|
128
|
-
|
129
|
-
truncation=True,
|
130
|
-
return_length=True,
|
131
|
-
return_tensors="pt",
|
132
|
-
)
|
133
|
-
text_input_ids = text_inputs.input_ids
|
134
|
-
|
135
|
-
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
136
|
-
|
137
|
-
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
138
|
-
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
139
|
-
logger.warning(
|
140
|
-
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
141
|
-
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
142
|
-
)
|
143
|
-
|
144
|
-
text_encoder_out = self.text_encoder(
|
145
|
-
text_input_ids.to(device),
|
146
|
-
output_hidden_states=True,
|
147
|
-
)
|
148
|
-
text_embeddings = text_encoder_out.hidden_states[-1]
|
149
|
-
text_pooler_out = text_encoder_out.pooler_output
|
150
|
-
|
151
|
-
# get unconditional embeddings for classifier free guidance
|
152
|
-
if do_classifier_free_guidance:
|
153
|
-
uncond_tokens: List[str]
|
154
|
-
if negative_prompt is None:
|
155
|
-
uncond_tokens = [""] * batch_size
|
156
|
-
elif type(prompt) is not type(negative_prompt):
|
157
|
-
raise TypeError(
|
158
|
-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
159
|
-
f" {type(prompt)}."
|
160
|
-
)
|
161
|
-
elif isinstance(negative_prompt, str):
|
162
|
-
uncond_tokens = [negative_prompt]
|
163
|
-
elif batch_size != len(negative_prompt):
|
164
|
-
raise ValueError(
|
165
|
-
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
166
|
-
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
167
|
-
" the batch size of `prompt`."
|
168
|
-
)
|
169
|
-
else:
|
170
|
-
uncond_tokens = negative_prompt
|
198
|
+
if prompt is not None and isinstance(prompt, str):
|
199
|
+
batch_size = 1
|
200
|
+
elif prompt is not None and isinstance(prompt, list):
|
201
|
+
batch_size = len(prompt)
|
202
|
+
else:
|
203
|
+
batch_size = prompt_embeds.shape[0]
|
171
204
|
|
172
|
-
|
173
|
-
|
174
|
-
|
205
|
+
if prompt_embeds is None or pooled_prompt_embeds is None:
|
206
|
+
text_inputs = self.tokenizer(
|
207
|
+
prompt,
|
175
208
|
padding="max_length",
|
176
|
-
max_length=
|
209
|
+
max_length=self.tokenizer.model_max_length,
|
177
210
|
truncation=True,
|
178
211
|
return_length=True,
|
179
212
|
return_tensors="pt",
|
180
213
|
)
|
214
|
+
text_input_ids = text_inputs.input_ids
|
181
215
|
|
182
|
-
|
183
|
-
|
216
|
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
217
|
+
|
218
|
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
219
|
+
text_input_ids, untruncated_ids
|
220
|
+
):
|
221
|
+
removed_text = self.tokenizer.batch_decode(
|
222
|
+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
223
|
+
)
|
224
|
+
logger.warning(
|
225
|
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
226
|
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
227
|
+
)
|
228
|
+
|
229
|
+
text_encoder_out = self.text_encoder(
|
230
|
+
text_input_ids.to(device),
|
184
231
|
output_hidden_states=True,
|
185
232
|
)
|
233
|
+
prompt_embeds = text_encoder_out.hidden_states[-1]
|
234
|
+
pooled_prompt_embeds = text_encoder_out.pooler_output
|
186
235
|
|
187
|
-
|
188
|
-
|
236
|
+
# get unconditional embeddings for classifier free guidance
|
237
|
+
if do_classifier_free_guidance:
|
238
|
+
if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
|
239
|
+
uncond_tokens: List[str]
|
240
|
+
if negative_prompt is None:
|
241
|
+
uncond_tokens = [""] * batch_size
|
242
|
+
elif type(prompt) is not type(negative_prompt):
|
243
|
+
raise TypeError(
|
244
|
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
245
|
+
f" {type(prompt)}."
|
246
|
+
)
|
247
|
+
elif isinstance(negative_prompt, str):
|
248
|
+
uncond_tokens = [negative_prompt]
|
249
|
+
elif batch_size != len(negative_prompt):
|
250
|
+
raise ValueError(
|
251
|
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
252
|
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
253
|
+
" the batch size of `prompt`."
|
254
|
+
)
|
255
|
+
else:
|
256
|
+
uncond_tokens = negative_prompt
|
257
|
+
|
258
|
+
max_length = text_input_ids.shape[-1]
|
259
|
+
uncond_input = self.tokenizer(
|
260
|
+
uncond_tokens,
|
261
|
+
padding="max_length",
|
262
|
+
max_length=max_length,
|
263
|
+
truncation=True,
|
264
|
+
return_length=True,
|
265
|
+
return_tensors="pt",
|
266
|
+
)
|
267
|
+
|
268
|
+
uncond_encoder_out = self.text_encoder(
|
269
|
+
uncond_input.input_ids.to(device),
|
270
|
+
output_hidden_states=True,
|
271
|
+
)
|
189
272
|
|
190
|
-
|
191
|
-
|
192
|
-
# to avoid doing two forward passes
|
193
|
-
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
194
|
-
text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
|
273
|
+
negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
|
274
|
+
negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output
|
195
275
|
|
196
|
-
return
|
276
|
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
197
277
|
|
198
278
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
199
279
|
def decode_latents(self, latents):
|
@@ -207,12 +287,56 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
207
287
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
208
288
|
return image
|
209
289
|
|
210
|
-
def check_inputs(
|
211
|
-
|
290
|
+
def check_inputs(
|
291
|
+
self,
|
292
|
+
prompt,
|
293
|
+
image,
|
294
|
+
callback_steps,
|
295
|
+
negative_prompt=None,
|
296
|
+
prompt_embeds=None,
|
297
|
+
negative_prompt_embeds=None,
|
298
|
+
pooled_prompt_embeds=None,
|
299
|
+
negative_pooled_prompt_embeds=None,
|
300
|
+
):
|
301
|
+
if prompt is not None and prompt_embeds is not None:
|
302
|
+
raise ValueError(
|
303
|
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
304
|
+
" only forward one of the two."
|
305
|
+
)
|
306
|
+
elif prompt is None and prompt_embeds is None:
|
307
|
+
raise ValueError(
|
308
|
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
309
|
+
)
|
310
|
+
elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
|
212
311
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
213
312
|
|
313
|
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
314
|
+
raise ValueError(
|
315
|
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
316
|
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
317
|
+
)
|
318
|
+
|
319
|
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
320
|
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
321
|
+
raise ValueError(
|
322
|
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
323
|
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
324
|
+
f" {negative_prompt_embeds.shape}."
|
325
|
+
)
|
326
|
+
|
327
|
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
328
|
+
raise ValueError(
|
329
|
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
330
|
+
)
|
331
|
+
|
332
|
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
333
|
+
raise ValueError(
|
334
|
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
335
|
+
)
|
336
|
+
|
214
337
|
if (
|
215
338
|
not isinstance(image, torch.Tensor)
|
339
|
+
and not isinstance(image, np.ndarray)
|
216
340
|
and not isinstance(image, PIL.Image.Image)
|
217
341
|
and not isinstance(image, list)
|
218
342
|
):
|
@@ -222,10 +346,14 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
222
346
|
|
223
347
|
# verify batch size of prompt and image are same if image is a list or tensor
|
224
348
|
if isinstance(image, (list, torch.Tensor)):
|
225
|
-
if
|
226
|
-
|
349
|
+
if prompt is not None:
|
350
|
+
if isinstance(prompt, str):
|
351
|
+
batch_size = 1
|
352
|
+
else:
|
353
|
+
batch_size = len(prompt)
|
227
354
|
else:
|
228
|
-
batch_size =
|
355
|
+
batch_size = prompt_embeds.shape[0]
|
356
|
+
|
229
357
|
if isinstance(image, list):
|
230
358
|
image_batch_size = len(image)
|
231
359
|
else:
|
@@ -261,13 +389,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
261
389
|
@torch.no_grad()
|
262
390
|
def __call__(
|
263
391
|
self,
|
264
|
-
prompt: Union[str, List[str]],
|
392
|
+
prompt: Union[str, List[str]] = None,
|
265
393
|
image: PipelineImageInput = None,
|
266
394
|
num_inference_steps: int = 75,
|
267
395
|
guidance_scale: float = 9.0,
|
268
396
|
negative_prompt: Optional[Union[str, List[str]]] = None,
|
269
397
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
270
398
|
latents: Optional[torch.Tensor] = None,
|
399
|
+
prompt_embeds: Optional[torch.Tensor] = None,
|
400
|
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
401
|
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
402
|
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
271
403
|
output_type: Optional[str] = "pil",
|
272
404
|
return_dict: bool = True,
|
273
405
|
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
@@ -359,10 +491,22 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
359
491
|
"""
|
360
492
|
|
361
493
|
# 1. Check inputs
|
362
|
-
self.check_inputs(
|
494
|
+
self.check_inputs(
|
495
|
+
prompt,
|
496
|
+
image,
|
497
|
+
callback_steps,
|
498
|
+
negative_prompt,
|
499
|
+
prompt_embeds,
|
500
|
+
negative_prompt_embeds,
|
501
|
+
pooled_prompt_embeds,
|
502
|
+
negative_pooled_prompt_embeds,
|
503
|
+
)
|
363
504
|
|
364
505
|
# 2. Define call parameters
|
365
|
-
|
506
|
+
if prompt is not None:
|
507
|
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
508
|
+
else:
|
509
|
+
batch_size = prompt_embeds.shape[0]
|
366
510
|
device = self._execution_device
|
367
511
|
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
368
512
|
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
@@ -373,16 +517,32 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
373
517
|
prompt = [""] * batch_size
|
374
518
|
|
375
519
|
# 3. Encode input prompt
|
376
|
-
|
377
|
-
|
520
|
+
(
|
521
|
+
prompt_embeds,
|
522
|
+
negative_prompt_embeds,
|
523
|
+
pooled_prompt_embeds,
|
524
|
+
negative_pooled_prompt_embeds,
|
525
|
+
) = self.encode_prompt(
|
526
|
+
prompt,
|
527
|
+
device,
|
528
|
+
do_classifier_free_guidance,
|
529
|
+
negative_prompt,
|
530
|
+
prompt_embeds,
|
531
|
+
negative_prompt_embeds,
|
532
|
+
pooled_prompt_embeds,
|
533
|
+
negative_pooled_prompt_embeds,
|
378
534
|
)
|
379
535
|
|
536
|
+
if do_classifier_free_guidance:
|
537
|
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
538
|
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
|
539
|
+
|
380
540
|
# 4. Preprocess image
|
381
541
|
image = self.image_processor.preprocess(image)
|
382
|
-
image = image.to(dtype=
|
542
|
+
image = image.to(dtype=prompt_embeds.dtype, device=device)
|
383
543
|
if image.shape[1] == 3:
|
384
544
|
# encode image if not in latent-space yet
|
385
|
-
image = self.vae.encode(image)
|
545
|
+
image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
|
386
546
|
|
387
547
|
# 5. set timesteps
|
388
548
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
@@ -400,17 +560,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
400
560
|
inv_noise_level = (noise_level**2 + 1) ** (-0.5)
|
401
561
|
|
402
562
|
image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
|
403
|
-
image_cond = image_cond.to(
|
563
|
+
image_cond = image_cond.to(prompt_embeds.dtype)
|
404
564
|
|
405
565
|
noise_level_embed = torch.cat(
|
406
566
|
[
|
407
|
-
torch.ones(
|
408
|
-
torch.zeros(
|
567
|
+
torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
|
568
|
+
torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
|
409
569
|
],
|
410
570
|
dim=1,
|
411
571
|
)
|
412
572
|
|
413
|
-
timestep_condition = torch.cat([noise_level_embed,
|
573
|
+
timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)
|
414
574
|
|
415
575
|
# 6. Prepare latent variables
|
416
576
|
height, width = image.shape[2:]
|
@@ -420,7 +580,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
420
580
|
num_channels_latents,
|
421
581
|
height * 2, # 2x upscale
|
422
582
|
width * 2,
|
423
|
-
|
583
|
+
prompt_embeds.dtype,
|
424
584
|
device,
|
425
585
|
generator,
|
426
586
|
latents,
|
@@ -454,7 +614,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
|
454
614
|
noise_pred = self.unet(
|
455
615
|
scaled_model_input,
|
456
616
|
timestep,
|
457
|
-
encoder_hidden_states=
|
617
|
+
encoder_hidden_states=prompt_embeds,
|
458
618
|
timestep_cond=timestep_condition,
|
459
619
|
).sample
|
460
620
|
|