diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_audio/pipeline_stable_audio.py

@@ -26,6 +26,7 @@ from ...models import AutoencoderOobleck, StableAudioDiTModel
 from ...models.embeddings import get_1d_rotary_pos_embed
 from ...schedulers import EDMDPMSolverMultistepScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -34,6 +35,13 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
@@ -438,7 +446,7 @@ class StableAudioPipeline(DiffusionPipeline):
                 f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions"
             )

-        audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length
+        audio_vae_length = int(self.transformer.config.sample_size) * self.vae.hop_length
         audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length)

         # check num_channels
@@ -726,6 +734,9 @@ class StableAudioPipeline(DiffusionPipeline):
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 9. Post-processing
         if not output_type == "latent":
             audio = self.vae.decode(latents).sample
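
Note: the hunks above (and the matching Stable Diffusion hunks further down) all apply the same torch_xla pattern: guard the import at module level, then call xm.mark_step() once per denoising iteration so the lazy XLA graph is cut and executed step by step. A minimal standalone sketch of that pattern, with a dummy denoise_step standing in for the real scheduler/transformer update:

import torch

try:
    import torch_xla.core.xla_model as xm  # optional dependency, mirroring the guarded import above

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def run_denoising_loop(latents: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
    def denoise_step(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # placeholder for the scheduler/model update the real pipeline performs
        return x - 0.01 * t.float()

    for t in timesteps:
        latents = denoise_step(latents, t)
        if XLA_AVAILABLE:
            # on TPU/XLA this flushes the pending graph once per step; outside XLA the branch is skipped
            xm.mark_step()
    return latents


latents = run_denoising_loop(torch.zeros(1, 4, 8, 8), torch.arange(10))
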
diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
     def num_timesteps(self):
         return self._num_timesteps

+    def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
+        s = torch.tensor([0.008])
+        clamp_range = [0, 1]
+        min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
+        var = alphas_cumprod[t]
+        var = var.clamp(*clamp_range)
+        s, min_var = s.to(var.device), min_var.to(var.device)
+        ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
+        return ratio
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
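
Note: get_timestep_ratio_conditioning maps a discrete DDPM timestep to the continuous ratio in [0, 1] that the Stable Cascade decoder expects, by inverting the cosine schedule through alphas_cumprod. A rough standalone check of that mapping; the 1000-step linear-beta schedule below is only an illustrative assumption, not the scheduler the pipeline ships with:

import torch


def timestep_ratio_conditioning(t: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # same arithmetic as the method added above, lifted out of the pipeline class
    s = torch.tensor([0.008])
    min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
    var = alphas_cumprod[t].clamp(0, 1)
    return (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s


betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

for t in (0, 500, 999):
    print(t, timestep_ratio_conditioning(torch.tensor(t), alphas_cumprod).item())
# early timesteps give ratios near 0, late timesteps approach 1
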
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
             batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
         )

+        if isinstance(self.scheduler, DDPMWuerstchenScheduler):
+            timesteps = timesteps[:-1]
+        else:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
+                self.scheduler.config.clip_sample = False  # disample sample clipping
+                logger.warning(" set `clip_sample` to be False")
+
         # 6. Run denoising loop
-        self.
-
-
+        if hasattr(self.scheduler, "betas"):
+            alphas = 1.0 - self.scheduler.betas
+            alphas_cumprod = torch.cumprod(alphas, dim=0)
+        else:
+            alphas_cumprod = []
+
+        self._num_timesteps = len(timesteps)
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                if len(alphas_cumprod) > 0:
+                    timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
+                    timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
+                else:
+                    timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
+            else:
+                timestep_ratio = t.expand(latents.size(0)).to(dtype)

             # 7. Denoise latents
             predicted_latents = self.decoder(
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
                 predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)

             # 9. Renoise latents to next timestep
+            if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
+                timestep_ratio = t
             latents = self.scheduler.step(
                 model_output=predicted_latents,
                 timestep=timestep_ratio,
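
Note: with the decoder changes above, the loop no longer assumes DDPMWuerstchenScheduler; any scheduler exposing betas can drive it, because alphas_cumprod is rebuilt on the fly and each discrete t is converted into a conditioning ratio. A brief sketch of that derivation using diffusers' DDPMScheduler, chosen here only as an example of a betas-carrying scheduler:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# mirrors the `hasattr(self.scheduler, "betas")` branch added in the denoising loop above
if hasattr(scheduler, "betas"):
    alphas = 1.0 - scheduler.betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
else:
    alphas_cumprod = []

print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())  # close to 1.0 at t=0, near 0.0 at t=999
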
diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         return self._num_timesteps

     def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
-        s = torch.tensor([0.
+        s = torch.tensor([0.008])
         clamp_range = [0, 1]
         min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
         var = alphas_cumprod[t]
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
         if isinstance(self.scheduler, DDPMWuerstchenScheduler):
             timesteps = timesteps[:-1]
         else:
-            if self.scheduler.config.clip_sample:
+            if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
                 self.scheduler.config.clip_sample = False  # disample sample clipping
                 logger.warning(" set `clip_sample` to be False")
         # 6. Run denoising loop
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 EXAMPLE_DOC_STRING = """
@@ -57,9 +65,21 @@ EXAMPLE_DOC_STRING = """


 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-
-
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
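
Note: the body that follows this docstring matches the per-sample standard deviation of the guided prediction to that of the text-only prediction and then blends by guidance_rescale. A self-contained sketch of that computation (the rescale-and-blend lines are reconstructed from the paper's formula, not quoted from the diff):

import torch


def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0):
    # per-sample std over all non-batch dims, as in the context lines above
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale toward the text prediction's scale, then interpolate with the raw CFG output
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg


out = rescale_noise_cfg(torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8), guidance_rescale=0.7)
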
@@ -78,7 +98,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

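
Note: retrieve_timesteps, whose docstring these hunks turn into a raw string, is the module-level helper each pipeline copy uses to support custom timestep or sigma schedules. Roughly how it is called (scheduler choice and step count are arbitrary here):

from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)

# standard path: the scheduler spaces the timesteps itself
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_inference_steps, timesteps[:3])

# when the scheduler's set_timesteps supports it, explicit `timesteps=[...]` or `sigmas=[...]`
# lists can be passed instead of num_inference_steps; otherwise retrieve_timesteps raises.
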
@@ -137,7 +157,7 @@ class StableDiffusionPipeline(
     IPAdapterMixin,
     FromSingleFileMixin,
 ):
-
+    """
     Pipeline for text-to-image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
@@ -235,7 +255,12 @@ class StableDiffusionPipeline(
         is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
             version.parse(unet.config._diffusers_version).base_version
         ) < version.parse("0.9.0.dev0")
-
+        self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
+        is_unet_sample_size_less_64 = (
+            hasattr(unet.config, "sample_size")
+            and self._is_unet_config_sample_size_int
+            and unet.config.sample_size < 64
+        )
         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
             deprecation_message = (
                 "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -882,8 +907,18 @@ class StableDiffusionPipeline(
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

         # 0. Default height and width to unet
-
-
+        if not height or not width:
+            height = (
+                self.unet.config.sample_size
+                if self._is_unet_config_sample_size_int
+                else self.unet.config.sample_size[0]
+            )
+            width = (
+                self.unet.config.sample_size
+                if self._is_unet_config_sample_size_int
+                else self.unet.config.sample_size[1]
+            )
+            height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
         # to deal with lora scaling and other possible forward hooks

         # 1. Check inputs. Raise error if not correct
@@ -1036,6 +1071,9 @@ class StableDiffusionPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
@@ -1049,7 +1087,6 @@ class StableDiffusionPipeline(
             do_denormalize = [True] * image.shape[0]
         else:
             do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         # Offload all models
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -119,7 +119,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -60,7 +60,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
