diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_unipc_multistep.py
CHANGED
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
     return torch.tensor(betas, dtype=torch.float32)
 
 
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.Tensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.Tensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
     `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
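The function added above is the standard zero-terminal-SNR rescaling (Algorithm 1 of arXiv:2305.08891): it shifts and scales sqrt(alpha_bar) so that the cumulative alpha at the final timestep is exactly zero, i.e. the last training step carries pure noise. A quick sanity check of that property, assuming rescale_zero_terminal_snr from the hunk above is in scope; the linear beta schedule is illustrative, not taken from this diff:

import torch

betas = torch.linspace(1e-4, 0.02, 1000)  # illustrative linear schedule
rescaled = rescale_zero_terminal_snr(betas)

alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
# first value is preserved, last value is ~0 (zero terminal SNR)
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())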
@@ -127,6 +164,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
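Both new options are plain config fields, so they can be switched on when re-creating the scheduler from an existing pipeline's config. A minimal sketch; the checkpoint id is a placeholder assumption, not something this diff pins down:

from diffusers import DiffusionPipeline, UniPCMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("some/checkpoint")  # placeholder id
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    rescale_betas_zero_snr=True,  # new in 0.28.0: zero terminal SNR
    final_sigmas_type="zero",     # new in 0.28.0: default shown explicitly
)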
@@ -153,6 +197,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         use_karras_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
+        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        rescale_betas_zero_snr: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -165,10 +211,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        if rescale_betas_zero_snr:
+            # Close to 0 without being 0 so first sigma is not inf
+            # FP16 smallest positive subnormal works well here
+            self.alphas_cumprod[-1] = 2**-24
+
         # Currently we only support VP-type noise schedule
         self.alpha_t = torch.sqrt(self.alphas_cumprod)
         self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
@@ -182,7 +237,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             if solver_type in ["midpoint", "heun", "logrho"]:
                 self.register_to_config(solver_type="bh2")
             else:
-                raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
+                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
 
         self.predict_x0 = predict_x0
         # setable values
@@ -202,7 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -265,10 +320,25 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             sigmas = np.flip(sigmas).copy()
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
-            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = sigmas[-1]
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
             sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         self.sigmas = torch.from_numpy(sigmas)
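The branching above only changes the last entry of the sigma schedule. A small sketch comparing the tail under both settings, assuming 0.28.0 defaults otherwise:

from diffusers import UniPCMultistepScheduler

for final_sigmas_type in ("zero", "sigma_min"):
    sched = UniPCMultistepScheduler(final_sigmas_type=final_sigmas_type)
    sched.set_timesteps(10)
    # last sigma is 0.0 for "zero", or the training schedule's smallest sigma for "sigma_min"
    print(final_sigmas_type, sched.sigmas[-1].item())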
@@ -290,7 +360,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -355,7 +425,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -382,24 +452,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         r"""
         Convert the model output to the corresponding type the UniPC algorithm needs.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -452,27 +522,27 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     def multistep_uni_p_bh_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         order: int = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model at the current timestep.
             prev_timestep (`int`):
                 The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             order (`int`):
                 The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
@@ -557,7 +627,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             if order == 2:
                 rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
             else:
-                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
         else:
             D1s = None
 
@@ -581,30 +651,30 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     def multistep_uni_c_bh_update(
         self,
-        this_model_output: torch.FloatTensor,
+        this_model_output: torch.Tensor,
         *args,
-        last_sample: torch.FloatTensor = None,
-        this_sample: torch.FloatTensor = None,
+        last_sample: torch.Tensor = None,
+        this_sample: torch.Tensor = None,
         order: int = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the UniC (B(h) version).
 
         Args:
-            this_model_output (`torch.FloatTensor`):
+            this_model_output (`torch.Tensor`):
                 The model outputs at `x_t`.
             this_timestep (`int`):
                 The current timestep `t`.
-            last_sample (`torch.FloatTensor`):
+            last_sample (`torch.Tensor`):
                 The generated sample before the last predictor `x_{t-1}`.
-            this_sample (`torch.FloatTensor`):
+            this_sample (`torch.Tensor`):
                 The generated sample after the last predictor `x_{t}`.
             order (`int`):
                 The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The corrected sample tensor at the current timestep.
         """
         this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
@@ -695,7 +765,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         if order == 1:
             rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
         else:
-            rhos_c = torch.linalg.solve(R, b)
+            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
 
         if self.predict_x0:
             x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
@@ -751,9 +821,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -761,11 +831,11 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
         the multistep UniPC.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -830,17 +900,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -848,10 +918,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -862,10 +932,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
             schedule_timesteps = self.timesteps.to(original_samples.device)
             timesteps = timesteps.to(original_samples.device)
 
-        # begin_index is None when the scheduler is used for training
+        # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
        if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
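The three branches of that last hunk map to three call patterns: training (no begin_index set), img2img-style pipelines that add noise once before the denoising loop, and inpainting-style pipelines that re-noise between steps. A rough, self-contained sketch of the first two patterns; shapes and step counts are illustrative assumptions:

import torch
from diffusers import UniPCMultistepScheduler

sched = UniPCMultistepScheduler()
sched.set_timesteps(30)

# img2img-style: noise is added once, before any step() call, so step_index is
# still None and the begin_index branch is taken.
sched.set_begin_index(5)
init_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(init_latents)
latents = sched.add_noise(init_latents, noise, sched.timesteps[5:6])

# inpainting-style: once the loop has called step() at least once, step_index
# is no longer None, and any later add_noise() call uses the step_index branch.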
diffusers/schedulers/scheduling_utils.py
CHANGED
@@ -48,18 +48,27 @@ class KarrasDiffusionSchedulers(Enum):
     EDMEulerScheduler = 15
 
 
+AysSchedules = {
+    "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
+    "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
+    "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
+    "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
+    "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
+}
+
+
 @dataclass
 class SchedulerOutput(BaseOutput):
     """
     Base class for the output of a scheduler's `step` function.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
     """
 
-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
 
 
 class SchedulerMixin(PushToHubMixin):
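The new AysSchedules table carries NVIDIA's 10-step "Align Your Steps" schedules (timesteps for SD/SDXL, sigmas for SD/SDXL/SVD). A hedged usage sketch: the checkpoint id is a placeholder, and passing `timesteps=` assumes a pipeline whose `__call__` accepts custom timesteps, which the Stable Diffusion family pipelines support in this release:

from diffusers import StableDiffusionXLPipeline
from diffusers.schedulers.scheduling_utils import AysSchedules

sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]  # 10 steps

pipe = StableDiffusionXLPipeline.from_pretrained("some/sdxl-checkpoint")  # placeholder id
image = pipe("a photo of an astronaut", timesteps=sampling_schedule).images[0]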
@@ -112,9 +121,9 @@ class SchedulerMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.

diffusers/schedulers/scheduling_utils_flax.py
CHANGED
@@ -102,9 +102,9 @@ class FlaxSchedulerMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
diffusers/schedulers/scheduling_vq_diffusion.py
CHANGED
@@ -38,7 +38,7 @@ class VQDiffusionSchedulerOutput(BaseOutput):
     prev_sample: torch.LongTensor
 
 
-def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
+def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor:
     """
     Convert batch of vector of class indices into batch of log onehot vectors
 
@@ -50,7 +50,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
             number of classes to be used for the onehot vectors
 
     Returns:
-        `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
+        `torch.Tensor` of shape `(batch size, num classes, vector length)`:
             Log onehot vectors
     """
     x_onehot = F.one_hot(x, num_classes)
@@ -59,7 +59,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
     return log_x
 
 
-def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
+def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
     """
     Apply gumbel noise to `logits`
     """
@@ -199,7 +199,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: torch.long,
         sample: torch.LongTensor,
         generator: Optional[torch.Generator] = None,
@@ -210,7 +210,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
         [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computer.
 
         Args:
-            log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+            log_p_x_0: (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                 The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                 prediction for the masked class as the initial unnoised image cannot be masked.
             t (`torch.long`):
@@ -251,7 +251,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
         ```
 
         Args:
-            log_p_x_0 (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
+            log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                 The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                 prediction for the masked class as the initial unnoised image cannot be masked.
             x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
@@ -260,7 +260,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
                 The timestep that determines which transition matrix is used.
 
         Returns:
-            `torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
+            `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
                 The log probabilities for the predicted classes of the image at timestep `t-1`.
         """
         log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
@@ -354,7 +354,7 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
         return log_p_x_t_min_1
 
     def log_Q_t_transitioning_to_known_class(
-        self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
+        self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
     ):
         """
         Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
@@ -365,14 +365,14 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
                 The timestep that determines which transition matrix is used.
             x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                 The classes of each latent pixel at time `t`.
-            log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
+            log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
                 The log one-hot vectors of `x_t`.
             cumulative (`bool`):
                 If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
                 `True`, the cumulative transition matrix `0`->`t` is used.
 
         Returns:
-            `torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
+            `torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`:
                 Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
                 transition matrix.
 
diffusers/training_utils.py
CHANGED
@@ -1,12 +1,13 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from .models import UNet2DConditionModel
+from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
     return interpolation_mode
 
 
+def compute_dream_and_update_latents(
+    unet: UNet2DConditionModel,
+    noise_scheduler: SchedulerMixin,
+    timesteps: torch.Tensor,
+    noise: torch.Tensor,
+    noisy_latents: torch.Tensor,
+    target: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    dream_detail_preservation: float = 1.0,
+) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
+    DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
+    forward step without gradients.
+
+    Args:
+        `unet`: The state unet to use to make a prediction.
+        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
+        `timesteps`: The timesteps for the noise_scheduler to user.
+        `noise`: A tensor of noise in the shape of noisy_latents.
+        `noisy_latents`: Previously noise latents from the training loop.
+        `target`: The ground-truth tensor to predict after eps is removed.
+        `encoder_hidden_states`: Text embeddings from the text model.
+        `dream_detail_preservation`: A float value that indicates detail preservation level.
+          See reference.
+
+    Returns:
+        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
+    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
+
+    pred = None
+    with torch.no_grad():
+        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+    noisy_latents, target = (None, None)
+    if noise_scheduler.config.prediction_type == "epsilon":
+        predicted_noise = pred
+        delta_noise = (noise - predicted_noise).detach()
+        delta_noise.mul_(dream_lambda)
+        noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
+        target = target.add(delta_noise)
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        raise NotImplementedError("DREAM has not been implemented for v-prediction")
+    else:
+        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+    return noisy_latents, target
+
+
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
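compute_dream_and_update_latents is meant to slot into an epsilon-prediction training loop between noising and the loss. A sketch of the intended call site; the surrounding variable names (unet, noise_scheduler, latents, noise, timesteps, encoder_hidden_states) follow the usual diffusers training-script conventions and are assumptions, not part of this diff:

import torch.nn.functional as F
from diffusers.training_utils import compute_dream_and_update_latents

# assumed to exist in the loop: unet, noise_scheduler, latents, noise,
# timesteps, encoder_hidden_states
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
target = noise  # epsilon-prediction target

noisy_latents, target = compute_dream_and_update_latents(
    unet, noise_scheduler, timesteps, noise, noisy_latents, target,
    encoder_hidden_states, dream_detail_preservation=1.0,
)

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")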
diffusers/utils/__init__.py
CHANGED
@@ -58,19 +58,26 @@ from .import_utils import (
     get_objects_from_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_bitsandbytes_available,
     is_bs4_available,
     is_flax_available,
     is_ftfy_available,
+    is_google_colab,
     is_inflect_available,
     is_invisible_watermark_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_librosa_available,
+    is_matplotlib_available,
     is_note_seq_available,
+    is_notebook,
     is_onnx_available,
     is_peft_available,
+    is_peft_version,
+    is_safetensors_available,
     is_scipy_available,
     is_tensorboard_available,
+    is_timm_available,
     is_torch_available,
     is_torch_npu_available,
     is_torch_version,
diffusers/utils/dummy_pt_objects.py
CHANGED
@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class ControlNetXSAdapter(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class I2VGenXLUNet(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class UNetControlNetXSModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class UNetMotionModel(metaclass=DummyObject):
     _backends = ["torch"]
 