diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +7 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +274 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
- diffusers-0.27.0.dist-info/RECORD +399 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
- diffusers-0.26.3.dist-info/RECORD +0 -384
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -98,15 +98,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
98
98
|
self.custom_timesteps = False
|
99
99
|
self.is_scale_input_called = False
|
100
100
|
self._step_index = None
|
101
|
+
self._begin_index = None
|
101
102
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
102
103
|
|
103
|
-
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
104
|
-
if schedule_timesteps is None:
|
105
|
-
schedule_timesteps = self.timesteps
|
106
|
-
|
107
|
-
indices = (schedule_timesteps == timestep).nonzero()
|
108
|
-
return indices.item()
|
109
|
-
|
110
104
|
@property
|
111
105
|
def step_index(self):
|
112
106
|
"""
|
@@ -114,6 +108,24 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
114
108
|
"""
|
115
109
|
return self._step_index
|
116
110
|
|
111
|
+
@property
|
112
|
+
def begin_index(self):
|
113
|
+
"""
|
114
|
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
115
|
+
"""
|
116
|
+
return self._begin_index
|
117
|
+
|
118
|
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
119
|
+
def set_begin_index(self, begin_index: int = 0):
|
120
|
+
"""
|
121
|
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
begin_index (`int`):
|
125
|
+
The begin index for the scheduler.
|
126
|
+
"""
|
127
|
+
self._begin_index = begin_index
|
128
|
+
|
117
129
|
def scale_model_input(
|
118
130
|
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
119
131
|
) -> torch.FloatTensor:
|
@@ -231,6 +243,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
231
243
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
232
244
|
|
233
245
|
self._step_index = None
|
246
|
+
self._begin_index = None
|
234
247
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
235
248
|
|
236
249
|
# Modified _convert_to_karras implementation that takes in ramp as argument
|
@@ -280,23 +293,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
280
293
|
c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
|
281
294
|
return c_skip, c_out
|
282
295
|
|
283
|
-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.
|
284
|
-
def
|
285
|
-
if
|
286
|
-
|
296
|
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
297
|
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
298
|
+
if schedule_timesteps is None:
|
299
|
+
schedule_timesteps = self.timesteps
|
287
300
|
|
288
|
-
|
301
|
+
indices = (schedule_timesteps == timestep).nonzero()
|
289
302
|
|
290
303
|
# The sigma index that is taken for the **very** first `step`
|
291
304
|
# is always the second index (or the last index if there is only 1)
|
292
305
|
# This way we can ensure we don't accidentally skip a sigma in
|
293
306
|
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
294
|
-
if len(
|
295
|
-
|
296
|
-
|
297
|
-
step_index = index_candidates[0]
|
307
|
+
pos = 1 if len(indices) > 1 else 0
|
308
|
+
|
309
|
+
return indices[pos].item()
|
298
310
|
|
299
|
-
|
311
|
+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
312
|
+
def _init_step_index(self, timestep):
|
313
|
+
if self.begin_index is None:
|
314
|
+
if isinstance(timestep, torch.Tensor):
|
315
|
+
timestep = timestep.to(self.timesteps.device)
|
316
|
+
self._step_index = self.index_for_timestep(timestep)
|
317
|
+
else:
|
318
|
+
self._step_index = self._begin_index
|
300
319
|
|
301
320
|
def step(
|
302
321
|
self,
|
@@ -412,7 +431,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
|
412
431
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
413
432
|
timesteps = timesteps.to(original_samples.device)
|
414
433
|
|
415
|
-
|
434
|
+
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
435
|
+
if self.begin_index is None:
|
436
|
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
437
|
+
else:
|
438
|
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
416
439
|
|
417
440
|
sigma = sigmas[step_indices].flatten()
|
418
441
|
while len(sigma.shape) < len(original_samples.shape):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -157,9 +157,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|
157
157
|
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
158
158
|
otherwise it uses the alpha value at step 0.
|
159
159
|
steps_offset (`int`, defaults to 0):
|
160
|
-
An offset added to the inference steps
|
161
|
-
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
162
|
-
Diffusion.
|
160
|
+
An offset added to the inference steps, as required by some model families.
|
163
161
|
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
164
162
|
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
165
163
|
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -85,15 +85,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|
85
85
|
trained_betas (`jnp.ndarray`, optional):
|
86
86
|
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
87
87
|
clip_sample (`bool`, default `True`):
|
88
|
-
option to clip predicted sample between
|
88
|
+
option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
|
89
|
+
clip_sample_range (`float`, default `1.0`):
|
90
|
+
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
89
91
|
set_alpha_to_one (`bool`, default `True`):
|
90
92
|
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
91
93
|
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
92
94
|
otherwise it uses the value of alpha at step 0.
|
93
95
|
steps_offset (`int`, default `0`):
|
94
|
-
|
95
|
-
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
96
|
-
stable diffusion.
|
96
|
+
An offset added to the inference steps, as required by some model families.
|
97
97
|
prediction_type (`str`, default `epsilon`):
|
98
98
|
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
99
99
|
`v-prediction` is not supported for this scheduler.
|
@@ -117,6 +117,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|
117
117
|
beta_end: float = 0.02,
|
118
118
|
beta_schedule: str = "linear",
|
119
119
|
trained_betas: Optional[jnp.ndarray] = None,
|
120
|
+
clip_sample: bool = True,
|
121
|
+
clip_sample_range: float = 1.0,
|
120
122
|
set_alpha_to_one: bool = True,
|
121
123
|
steps_offset: int = 0,
|
122
124
|
prediction_type: str = "epsilon",
|
@@ -267,6 +269,12 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|
267
269
|
" `v_prediction`"
|
268
270
|
)
|
269
271
|
|
272
|
+
# 4. Clip or threshold "predicted x_0"
|
273
|
+
if self.config.clip_sample:
|
274
|
+
pred_original_sample = pred_original_sample.clip(
|
275
|
+
-self.config.clip_sample_range, self.config.clip_sample_range
|
276
|
+
)
|
277
|
+
|
270
278
|
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
271
279
|
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
272
280
|
variance = self._get_variance(state, timestep, prev_timestep)
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -155,9 +155,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
155
155
|
there is no previous alpha. When this option is `True` the previous alpha product is fixed to 0, otherwise
|
156
156
|
it uses the alpha value at step `num_train_timesteps - 1`.
|
157
157
|
steps_offset (`int`, defaults to 0):
|
158
|
-
An offset added to the inference steps
|
159
|
-
`set_alpha_to_one=False` to make the last step use `num_train_timesteps - 1` for the previous alpha
|
160
|
-
product.
|
158
|
+
An offset added to the inference steps, as required by some model families.
|
161
159
|
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
162
160
|
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
163
161
|
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -159,9 +159,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
159
159
|
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
160
160
|
otherwise it uses the value of alpha at step 0.
|
161
161
|
steps_offset (`int`, default `0`):
|
162
|
-
|
163
|
-
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
164
|
-
stable diffusion.
|
162
|
+
An offset added to the inference steps, as required by some model families.
|
165
163
|
prediction_type (`str`, default `epsilon`, optional):
|
166
164
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
167
165
|
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -167,9 +167,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|
167
167
|
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
168
168
|
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
169
169
|
steps_offset (`int`, defaults to 0):
|
170
|
-
An offset added to the inference steps
|
171
|
-
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
172
|
-
Diffusion.
|
170
|
+
An offset added to the inference steps, as required by some model families.
|
173
171
|
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
174
172
|
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
175
173
|
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -173,9 +173,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
173
173
|
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
|
174
174
|
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
|
175
175
|
steps_offset (`int`, default `0`):
|
176
|
-
|
177
|
-
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
178
|
-
stable diffusion.
|
176
|
+
An offset added to the inference steps, as required by some model families.
|
179
177
|
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
180
178
|
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
181
179
|
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
@@ -1,5 +1,5 @@
|
|
1
1
|
# Copyright (c) 2022 Pablo Pernías MIT License
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -115,9 +115,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
115
115
|
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
116
116
|
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
117
117
|
steps_offset (`int`, defaults to 0):
|
118
|
-
An offset added to the inference steps
|
119
|
-
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
120
|
-
Diffusion.
|
118
|
+
An offset added to the inference steps, as required by some model families.
|
121
119
|
"""
|
122
120
|
|
123
121
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -187,6 +185,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
187
185
|
self.model_outputs = [None] * solver_order
|
188
186
|
self.lower_order_nums = 0
|
189
187
|
self._step_index = None
|
188
|
+
self._begin_index = None
|
190
189
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
191
190
|
|
192
191
|
@property
|
@@ -196,6 +195,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
196
195
|
"""
|
197
196
|
return self._step_index
|
198
197
|
|
198
|
+
@property
|
199
|
+
def begin_index(self):
|
200
|
+
"""
|
201
|
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
202
|
+
"""
|
203
|
+
return self._begin_index
|
204
|
+
|
205
|
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
206
|
+
def set_begin_index(self, begin_index: int = 0):
|
207
|
+
"""
|
208
|
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
209
|
+
|
210
|
+
Args:
|
211
|
+
begin_index (`int`):
|
212
|
+
The begin index for the scheduler.
|
213
|
+
"""
|
214
|
+
self._begin_index = begin_index
|
215
|
+
|
199
216
|
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
200
217
|
"""
|
201
218
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
@@ -255,6 +272,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
255
272
|
|
256
273
|
# add an index counter for schedulers that allow duplicated timesteps
|
257
274
|
self._step_index = None
|
275
|
+
self._begin_index = None
|
258
276
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
259
277
|
|
260
278
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
@@ -620,11 +638,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
620
638
|
else:
|
621
639
|
raise NotImplementedError("only support log-rho multistep deis now")
|
622
640
|
|
623
|
-
|
624
|
-
|
625
|
-
|
641
|
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
642
|
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
643
|
+
if schedule_timesteps is None:
|
644
|
+
schedule_timesteps = self.timesteps
|
626
645
|
|
627
|
-
index_candidates = (
|
646
|
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
628
647
|
|
629
648
|
if len(index_candidates) == 0:
|
630
649
|
step_index = len(self.timesteps) - 1
|
@@ -637,7 +656,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
637
656
|
else:
|
638
657
|
step_index = index_candidates[0].item()
|
639
658
|
|
640
|
-
|
659
|
+
return step_index
|
660
|
+
|
661
|
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
662
|
+
def _init_step_index(self, timestep):
|
663
|
+
"""
|
664
|
+
Initialize the step_index counter for the scheduler.
|
665
|
+
"""
|
666
|
+
|
667
|
+
if self.begin_index is None:
|
668
|
+
if isinstance(timestep, torch.Tensor):
|
669
|
+
timestep = timestep.to(self.timesteps.device)
|
670
|
+
self._step_index = self.index_for_timestep(timestep)
|
671
|
+
else:
|
672
|
+
self._step_index = self._begin_index
|
641
673
|
|
642
674
|
def step(
|
643
675
|
self,
|
@@ -736,16 +768,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
736
768
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
737
769
|
timesteps = timesteps.to(original_samples.device)
|
738
770
|
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
elif len(index_candidates) > 1:
|
745
|
-
step_index = index_candidates[1].item()
|
746
|
-
else:
|
747
|
-
step_index = index_candidates[0].item()
|
748
|
-
step_indices.append(step_index)
|
771
|
+
# begin_index is None when the scheduler is used for training
|
772
|
+
if self.begin_index is None:
|
773
|
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
774
|
+
else:
|
775
|
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
749
776
|
|
750
777
|
sigma = sigmas[step_indices].flatten()
|
751
778
|
while len(sigma.shape) < len(original_samples.shape):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -71,6 +71,43 @@ def betas_for_alpha_bar(
|
|
71
71
|
return torch.tensor(betas, dtype=torch.float32)
|
72
72
|
|
73
73
|
|
74
|
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
75
|
+
def rescale_zero_terminal_snr(betas):
|
76
|
+
"""
|
77
|
+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
78
|
+
|
79
|
+
|
80
|
+
Args:
|
81
|
+
betas (`torch.FloatTensor`):
|
82
|
+
the betas that the scheduler is being initialized with.
|
83
|
+
|
84
|
+
Returns:
|
85
|
+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
86
|
+
"""
|
87
|
+
# Convert betas to alphas_bar_sqrt
|
88
|
+
alphas = 1.0 - betas
|
89
|
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
90
|
+
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
91
|
+
|
92
|
+
# Store old values.
|
93
|
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
94
|
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
95
|
+
|
96
|
+
# Shift so the last timestep is zero.
|
97
|
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
98
|
+
|
99
|
+
# Scale so the first timestep is back to the old value.
|
100
|
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
101
|
+
|
102
|
+
# Convert alphas_bar_sqrt to betas
|
103
|
+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
104
|
+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
105
|
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
106
|
+
betas = 1 - alphas
|
107
|
+
|
108
|
+
return betas
|
109
|
+
|
110
|
+
|
74
111
|
class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
75
112
|
"""
|
76
113
|
`DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
@@ -141,9 +178,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
141
178
|
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
142
179
|
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
143
180
|
steps_offset (`int`, defaults to 0):
|
144
|
-
An offset added to the inference steps
|
145
|
-
|
146
|
-
|
181
|
+
An offset added to the inference steps, as required by some model families.
|
182
|
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
183
|
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
184
|
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
185
|
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
147
186
|
"""
|
148
187
|
|
149
188
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -173,6 +212,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
173
212
|
variance_type: Optional[str] = None,
|
174
213
|
timestep_spacing: str = "linspace",
|
175
214
|
steps_offset: int = 0,
|
215
|
+
rescale_betas_zero_snr: bool = False,
|
176
216
|
):
|
177
217
|
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
178
218
|
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
@@ -191,8 +231,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
191
231
|
else:
|
192
232
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
193
233
|
|
234
|
+
if rescale_betas_zero_snr:
|
235
|
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
236
|
+
|
194
237
|
self.alphas = 1.0 - self.betas
|
195
238
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
239
|
+
|
240
|
+
if rescale_betas_zero_snr:
|
241
|
+
# Close to 0 without being 0 so first sigma is not inf
|
242
|
+
# FP16 smallest positive subnormal works well here
|
243
|
+
self.alphas_cumprod[-1] = 2**-24
|
244
|
+
|
196
245
|
# Currently we only support VP-type noise schedule
|
197
246
|
self.alpha_t = torch.sqrt(self.alphas_cumprod)
|
198
247
|
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
@@ -227,6 +276,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
227
276
|
self.model_outputs = [None] * solver_order
|
228
277
|
self.lower_order_nums = 0
|
229
278
|
self._step_index = None
|
279
|
+
self._begin_index = None
|
230
280
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
231
281
|
|
232
282
|
@property
|
@@ -236,6 +286,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
236
286
|
"""
|
237
287
|
return self._step_index
|
238
288
|
|
289
|
+
@property
|
290
|
+
def begin_index(self):
|
291
|
+
"""
|
292
|
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
293
|
+
"""
|
294
|
+
return self._begin_index
|
295
|
+
|
296
|
+
def set_begin_index(self, begin_index: int = 0):
|
297
|
+
"""
|
298
|
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
299
|
+
|
300
|
+
Args:
|
301
|
+
begin_index (`int`):
|
302
|
+
The begin index for the scheduler.
|
303
|
+
"""
|
304
|
+
self._begin_index = begin_index
|
305
|
+
|
239
306
|
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
240
307
|
"""
|
241
308
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
@@ -311,6 +378,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
311
378
|
|
312
379
|
# add an index counter for schedulers that allow duplicated timesteps
|
313
380
|
self._step_index = None
|
381
|
+
self._begin_index = None
|
314
382
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
315
383
|
|
316
384
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
@@ -792,11 +860,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
792
860
|
)
|
793
861
|
return x_t
|
794
862
|
|
795
|
-
def
|
796
|
-
if
|
797
|
-
|
863
|
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
864
|
+
if schedule_timesteps is None:
|
865
|
+
schedule_timesteps = self.timesteps
|
798
866
|
|
799
|
-
index_candidates = (
|
867
|
+
index_candidates = (schedule_timesteps == timestep).nonzero()
|
800
868
|
|
801
869
|
if len(index_candidates) == 0:
|
802
870
|
step_index = len(self.timesteps) - 1
|
@@ -809,7 +877,19 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
809
877
|
else:
|
810
878
|
step_index = index_candidates[0].item()
|
811
879
|
|
812
|
-
|
880
|
+
return step_index
|
881
|
+
|
882
|
+
def _init_step_index(self, timestep):
|
883
|
+
"""
|
884
|
+
Initialize the step_index counter for the scheduler.
|
885
|
+
"""
|
886
|
+
|
887
|
+
if self.begin_index is None:
|
888
|
+
if isinstance(timestep, torch.Tensor):
|
889
|
+
timestep = timestep.to(self.timesteps.device)
|
890
|
+
self._step_index = self.index_for_timestep(timestep)
|
891
|
+
else:
|
892
|
+
self._step_index = self._begin_index
|
813
893
|
|
814
894
|
def step(
|
815
895
|
self,
|
@@ -817,6 +897,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
817
897
|
timestep: int,
|
818
898
|
sample: torch.FloatTensor,
|
819
899
|
generator=None,
|
900
|
+
variance_noise: Optional[torch.FloatTensor] = None,
|
820
901
|
return_dict: bool = True,
|
821
902
|
) -> Union[SchedulerOutput, Tuple]:
|
822
903
|
"""
|
@@ -832,6 +913,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
832
913
|
A current instance of a sample created by the diffusion process.
|
833
914
|
generator (`torch.Generator`, *optional*):
|
834
915
|
A random number generator.
|
916
|
+
variance_noise (`torch.FloatTensor`):
|
917
|
+
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
918
|
+
itself. Useful for methods such as [`LEdits++`].
|
835
919
|
return_dict (`bool`):
|
836
920
|
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
837
921
|
|
@@ -864,10 +948,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
864
948
|
self.model_outputs[i] = self.model_outputs[i + 1]
|
865
949
|
self.model_outputs[-1] = model_output
|
866
950
|
|
867
|
-
|
951
|
+
# Upcast to avoid precision issues when computing prev_sample
|
952
|
+
sample = sample.to(torch.float32)
|
953
|
+
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
868
954
|
noise = randn_tensor(
|
869
|
-
model_output.shape, generator=generator, device=model_output.device, dtype=
|
955
|
+
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
870
956
|
)
|
957
|
+
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
958
|
+
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
871
959
|
else:
|
872
960
|
noise = None
|
873
961
|
|
@@ -881,6 +969,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
881
969
|
if self.lower_order_nums < self.config.solver_order:
|
882
970
|
self.lower_order_nums += 1
|
883
971
|
|
972
|
+
# Cast sample back to expected dtype
|
973
|
+
prev_sample = prev_sample.to(model_output.dtype)
|
974
|
+
|
884
975
|
# upon completion increase step index by one
|
885
976
|
self._step_index += 1
|
886
977
|
|
@@ -920,16 +1011,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
920
1011
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
921
1012
|
timesteps = timesteps.to(original_samples.device)
|
922
1013
|
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
elif len(index_candidates) > 1:
|
929
|
-
step_index = index_candidates[1].item()
|
930
|
-
else:
|
931
|
-
step_index = index_candidates[0].item()
|
932
|
-
step_indices.append(step_index)
|
1014
|
+
# begin_index is None when the scheduler is used for training
|
1015
|
+
if self.begin_index is None:
|
1016
|
+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
1017
|
+
else:
|
1018
|
+
step_indices = [self.begin_index] * timesteps.shape[0]
|
933
1019
|
|
934
1020
|
sigma = sigmas[step_indices].flatten()
|
935
1021
|
while len(sigma.shape) < len(original_samples.shape):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|