diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -92,6 +92,43 @@ def betas_for_alpha_bar(
|
|
92
92
|
return torch.tensor(betas, dtype=torch.float32)
|
93
93
|
|
94
94
|
|
95
|
+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
96
|
+
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is exactly zero.

    Follows Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf: the square roots of the cumulative alpha
    products are shifted so that the final one becomes zero, stretched so that the first one keeps its
    original value, and then converted back into per-step betas.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # sqrt(alpha_bar_t), where alpha_bar_t is the running product of (1 - beta) up to step t.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve; this is an out-of-place update, so no defensive copies are needed.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the last timestep is zero, then scale so the first timestep returns to its old value.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product: alpha_t = alpha_bar_t / alpha_bar_{t-1},
    # with alpha_0 = alpha_bar_0.
    alpha_bar = sqrt_alpha_bar**2
    per_step_alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - per_step_alphas
|
130
|
+
|
131
|
+
|
95
132
|
class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
96
133
|
"""
|
97
134
|
Euler scheduler.
|
@@ -128,6 +165,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
128
165
|
An offset added to the inference steps. You can use a combination of `offset=1` and
|
129
166
|
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
130
167
|
Diffusion.
|
168
|
+
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
169
|
+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
170
|
+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
171
|
+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
131
172
|
"""
|
132
173
|
|
133
174
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
@@ -144,8 +185,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
144
185
|
prediction_type: str = "epsilon",
|
145
186
|
interpolation_type: str = "linear",
|
146
187
|
use_karras_sigmas: Optional[bool] = False,
|
188
|
+
sigma_min: Optional[float] = None,
|
189
|
+
sigma_max: Optional[float] = None,
|
147
190
|
timestep_spacing: str = "linspace",
|
191
|
+
timestep_type: str = "discrete", # can be "discrete" or "continuous"
|
148
192
|
steps_offset: int = 0,
|
193
|
+
rescale_betas_zero_snr: bool = False,
|
149
194
|
):
|
150
195
|
if trained_betas is not None:
|
151
196
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
@@ -153,38 +198,55 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
153
198
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
154
199
|
elif beta_schedule == "scaled_linear":
|
155
200
|
# this schedule is very specific to the latent diffusion model.
|
156
|
-
self.betas = (
|
157
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
158
|
-
)
|
201
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
159
202
|
elif beta_schedule == "squaredcos_cap_v2":
|
160
203
|
# Glide cosine schedule
|
161
204
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
162
205
|
else:
|
163
206
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
164
207
|
|
208
|
+
if rescale_betas_zero_snr:
|
209
|
+
self.betas = rescale_zero_terminal_snr(self.betas)
|
210
|
+
|
165
211
|
self.alphas = 1.0 - self.betas
|
166
212
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
167
213
|
|
214
|
+
if rescale_betas_zero_snr:
|
215
|
+
# Close to 0 without being 0 so first sigma is not inf
|
216
|
+
# FP16 smallest positive subnormal works well here
|
217
|
+
self.alphas_cumprod[-1] = 2**-24
|
218
|
+
|
168
219
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
169
|
-
|
170
|
-
|
220
|
+
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
|
221
|
+
|
222
|
+
sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
|
223
|
+
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
171
224
|
|
172
225
|
# setable values
|
173
226
|
self.num_inference_steps = None
|
174
|
-
|
175
|
-
|
227
|
+
|
228
|
+
# TODO: Support the full EDM scalings for all prediction types and timestep types
|
229
|
+
if timestep_type == "continuous" and prediction_type == "v_prediction":
|
230
|
+
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
|
231
|
+
else:
|
232
|
+
self.timesteps = timesteps
|
233
|
+
|
234
|
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
235
|
+
|
176
236
|
self.is_scale_input_called = False
|
177
237
|
self.use_karras_sigmas = use_karras_sigmas
|
178
238
|
|
179
239
|
self._step_index = None
|
240
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
180
241
|
|
181
242
|
@property
def init_noise_sigma(self):
    """Standard deviation of the initial noise distribution.

    For "linspace" and "trailing" timestep spacing this is the largest sigma. For any other spacing it is
    sqrt(largest_sigma**2 + 1). `self.sigmas` may be either a Python list or a tensor.
    """
    sigmas = self.sigmas
    largest = max(sigmas) if isinstance(sigmas, list) else sigmas.max()

    if self.config.timestep_spacing not in ["linspace", "trailing"]:
        return (largest**2 + 1) ** 0.5
    return largest
|
188
250
|
|
189
251
|
@property
|
190
252
|
def step_index(self):
|
@@ -259,7 +321,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
259
321
|
if self.config.interpolation_type == "linear":
|
260
322
|
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
261
323
|
elif self.config.interpolation_type == "log_linear":
|
262
|
-
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp()
|
324
|
+
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
|
263
325
|
else:
|
264
326
|
raise ValueError(
|
265
327
|
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
|
@@ -270,11 +332,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
270
332
|
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
271
333
|
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
272
334
|
|
273
|
-
sigmas =
|
274
|
-
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
335
|
+
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
275
336
|
|
276
|
-
|
337
|
+
# TODO: Support the full EDM scalings for all prediction types and timestep types
|
338
|
+
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
|
339
|
+
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
|
340
|
+
else:
|
341
|
+
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
|
342
|
+
|
343
|
+
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
277
344
|
self._step_index = None
|
345
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
278
346
|
|
279
347
|
def _sigma_to_t(self, sigma, log_sigmas):
|
280
348
|
# get log sigma
|
@@ -303,8 +371,20 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
303
371
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
304
372
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
305
373
|
|
306
|
-
|
307
|
-
|
374
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
375
|
+
# TODO: Add this logic to the other schedulers
|
376
|
+
if hasattr(self.config, "sigma_min"):
|
377
|
+
sigma_min = self.config.sigma_min
|
378
|
+
else:
|
379
|
+
sigma_min = None
|
380
|
+
|
381
|
+
if hasattr(self.config, "sigma_max"):
|
382
|
+
sigma_max = self.config.sigma_max
|
383
|
+
else:
|
384
|
+
sigma_max = None
|
385
|
+
|
386
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
387
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
308
388
|
|
309
389
|
rho = 7.0 # 7.0 is the value used in the paper
|
310
390
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -392,6 +472,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
392
472
|
if self.step_index is None:
|
393
473
|
self._init_step_index(timestep)
|
394
474
|
|
475
|
+
# Upcast to avoid precision issues when computing prev_sample
|
476
|
+
sample = sample.to(torch.float32)
|
477
|
+
|
395
478
|
sigma = self.sigmas[self.step_index]
|
396
479
|
|
397
480
|
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
@@ -414,7 +497,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
414
497
|
elif self.config.prediction_type == "epsilon":
|
415
498
|
pred_original_sample = sample - sigma_hat * model_output
|
416
499
|
elif self.config.prediction_type == "v_prediction":
|
417
|
-
# * c_out + input * c_skip
|
500
|
+
# denoised = model_output * c_out + input * c_skip
|
418
501
|
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
419
502
|
else:
|
420
503
|
raise ValueError(
|
@@ -428,6 +511,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
428
511
|
|
429
512
|
prev_sample = sample + derivative * dt
|
430
513
|
|
514
|
+
# Cast sample back to model compatible dtype
|
515
|
+
prev_sample = prev_sample.to(model_output.dtype)
|
516
|
+
|
431
517
|
# upon completion increase step index by one
|
432
518
|
self._step_index += 1
|
433
519
|
|
@@ -131,9 +131,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
131
131
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
132
132
|
elif beta_schedule == "scaled_linear":
|
133
133
|
# this schedule is very specific to the latent diffusion model.
|
134
|
-
self.betas = (
|
135
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
136
|
-
)
|
134
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
137
135
|
elif beta_schedule == "squaredcos_cap_v2":
|
138
136
|
# Glide cosine schedule
|
139
137
|
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
|
@@ -150,6 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
150
148
|
self.use_karras_sigmas = use_karras_sigmas
|
151
149
|
|
152
150
|
self._step_index = None
|
151
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
153
152
|
|
154
153
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
155
154
|
if schedule_timesteps is None:
|
@@ -271,6 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
271
270
|
self.dt = None
|
272
271
|
|
273
272
|
self._step_index = None
|
273
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
274
274
|
|
275
275
|
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
276
276
|
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
@@ -305,8 +305,20 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
305
305
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
306
306
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
307
307
|
|
308
|
-
|
309
|
-
|
308
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
309
|
+
# TODO: Add this logic to the other schedulers
|
310
|
+
if hasattr(self.config, "sigma_min"):
|
311
|
+
sigma_min = self.config.sigma_min
|
312
|
+
else:
|
313
|
+
sigma_min = None
|
314
|
+
|
315
|
+
if hasattr(self.config, "sigma_max"):
|
316
|
+
sigma_max = self.config.sigma_max
|
317
|
+
else:
|
318
|
+
sigma_max = None
|
319
|
+
|
320
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
321
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
310
322
|
|
311
323
|
rho = 7.0 # 7.0 is the value used in the paper
|
312
324
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -127,9 +127,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
127
127
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
128
128
|
elif beta_schedule == "scaled_linear":
|
129
129
|
# this schedule is very specific to the latent diffusion model.
|
130
|
-
self.betas = (
|
131
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
132
|
-
)
|
130
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
133
131
|
elif beta_schedule == "squaredcos_cap_v2":
|
134
132
|
# Glide cosine schedule
|
135
133
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -142,6 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
142
140
|
# set all values
|
143
141
|
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
144
142
|
self._step_index = None
|
143
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
145
144
|
|
146
145
|
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
147
146
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
@@ -297,6 +296,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
297
296
|
self._index_counter = defaultdict(int)
|
298
297
|
|
299
298
|
self._step_index = None
|
299
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
300
300
|
|
301
301
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
302
302
|
def _sigma_to_t(self, sigma, log_sigmas):
|
@@ -326,8 +326,20 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
326
326
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
327
327
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
328
328
|
|
329
|
-
|
330
|
-
|
329
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
330
|
+
# TODO: Add this logic to the other schedulers
|
331
|
+
if hasattr(self.config, "sigma_min"):
|
332
|
+
sigma_min = self.config.sigma_min
|
333
|
+
else:
|
334
|
+
sigma_min = None
|
335
|
+
|
336
|
+
if hasattr(self.config, "sigma_max"):
|
337
|
+
sigma_max = self.config.sigma_max
|
338
|
+
else:
|
339
|
+
sigma_max = None
|
340
|
+
|
341
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
342
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
331
343
|
|
332
344
|
rho = 7.0 # 7.0 is the value used in the paper
|
333
345
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -126,9 +126,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
126
126
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
127
127
|
elif beta_schedule == "scaled_linear":
|
128
128
|
# this schedule is very specific to the latent diffusion model.
|
129
|
-
self.betas = (
|
130
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
131
|
-
)
|
129
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
132
130
|
elif beta_schedule == "squaredcos_cap_v2":
|
133
131
|
# Glide cosine schedule
|
134
132
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -142,6 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
142
140
|
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
143
141
|
|
144
142
|
self._step_index = None
|
143
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
145
144
|
|
146
145
|
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
147
146
|
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
@@ -286,6 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
286
285
|
self._index_counter = defaultdict(int)
|
287
286
|
|
288
287
|
self._step_index = None
|
288
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
289
289
|
|
290
290
|
@property
|
291
291
|
def state_in_first_order(self):
|
@@ -337,8 +337,20 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
337
337
|
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
338
338
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
339
339
|
|
340
|
-
|
341
|
-
|
340
|
+
# Hack to make sure that other schedulers which copy this function don't break
|
341
|
+
# TODO: Add this logic to the other schedulers
|
342
|
+
if hasattr(self.config, "sigma_min"):
|
343
|
+
sigma_min = self.config.sigma_min
|
344
|
+
else:
|
345
|
+
sigma_min = None
|
346
|
+
|
347
|
+
if hasattr(self.config, "sigma_max"):
|
348
|
+
sigma_max = self.config.sigma_max
|
349
|
+
else:
|
350
|
+
sigma_max = None
|
351
|
+
|
352
|
+
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
353
|
+
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
342
354
|
|
343
355
|
rho = 7.0 # 7.0 is the value used in the paper
|
344
356
|
ramp = np.linspace(0, 1, num_inference_steps)
|
@@ -221,9 +221,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
221
221
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
222
222
|
elif beta_schedule == "scaled_linear":
|
223
223
|
# this schedule is very specific to the latent diffusion model.
|
224
|
-
self.betas = (
|
225
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
226
|
-
)
|
224
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
227
225
|
elif beta_schedule == "squaredcos_cap_v2":
|
228
226
|
# Glide cosine schedule
|
229
227
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -249,6 +247,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
249
247
|
# setable values
|
250
248
|
self.num_inference_steps = None
|
251
249
|
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
250
|
+
self.custom_timesteps = False
|
252
251
|
|
253
252
|
self._step_index = None
|
254
253
|
|
@@ -326,17 +325,19 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
326
325
|
|
327
326
|
def set_timesteps(
|
328
327
|
self,
|
329
|
-
num_inference_steps: int,
|
328
|
+
num_inference_steps: Optional[int] = None,
|
330
329
|
device: Union[str, torch.device] = None,
|
331
330
|
original_inference_steps: Optional[int] = None,
|
331
|
+
timesteps: Optional[List[int]] = None,
|
332
332
|
strength: int = 1.0,
|
333
333
|
):
|
334
334
|
"""
|
335
335
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
336
336
|
|
337
337
|
Args:
|
338
|
-
num_inference_steps (`int
|
339
|
-
The number of diffusion steps used when generating samples with a pre-trained model.
|
338
|
+
num_inference_steps (`int`, *optional*):
|
339
|
+
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
340
|
+
`timesteps` must be `None`.
|
340
341
|
device (`str` or `torch.device`, *optional*):
|
341
342
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
342
343
|
original_inference_steps (`int`, *optional*):
|
@@ -344,16 +345,19 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
344
345
|
schedule (which is different from the standard `diffusers` implementation). We will then take
|
345
346
|
`num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
|
346
347
|
our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
|
348
|
+
timesteps (`List[int]`, *optional*):
|
349
|
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
350
|
+
timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
|
351
|
+
schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
|
347
352
|
"""
|
353
|
+
# 0. Check inputs
|
354
|
+
if num_inference_steps is None and timesteps is None:
|
355
|
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
|
348
356
|
|
349
|
-
if num_inference_steps
|
350
|
-
raise ValueError(
|
351
|
-
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
352
|
-
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
353
|
-
f" maximal {self.config.num_train_timesteps} timesteps."
|
354
|
-
)
|
357
|
+
if num_inference_steps is not None and timesteps is not None:
|
358
|
+
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
355
359
|
|
356
|
-
|
360
|
+
# 1. Calculate the LCM original training/distillation timestep schedule.
|
357
361
|
original_steps = (
|
358
362
|
original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
|
359
363
|
)
|
@@ -365,23 +369,97 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
365
369
|
f" maximal {self.config.num_train_timesteps} timesteps."
|
366
370
|
)
|
367
371
|
|
368
|
-
if num_inference_steps > original_steps:
|
369
|
-
raise ValueError(
|
370
|
-
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
|
371
|
-
f" {original_steps} because the final timestep schedule will be a subset of the"
|
372
|
-
f" `original_inference_steps`-sized initial timestep schedule."
|
373
|
-
)
|
374
|
-
|
375
372
|
# LCM Timesteps Setting
|
376
|
-
#
|
377
|
-
|
378
|
-
# LCM Training Steps Schedule
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
373
|
+
# The skipping step parameter k from the paper.
|
374
|
+
k = self.config.num_train_timesteps // original_steps
|
375
|
+
# LCM Training/Distillation Steps Schedule
|
376
|
+
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
|
377
|
+
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
|
378
|
+
|
379
|
+
# 2. Calculate the LCM inference timestep schedule.
|
380
|
+
if timesteps is not None:
|
381
|
+
# 2.1 Handle custom timestep schedules.
|
382
|
+
train_timesteps = set(lcm_origin_timesteps)
|
383
|
+
non_train_timesteps = []
|
384
|
+
for i in range(1, len(timesteps)):
|
385
|
+
if timesteps[i] >= timesteps[i - 1]:
|
386
|
+
raise ValueError("`custom_timesteps` must be in descending order.")
|
387
|
+
|
388
|
+
if timesteps[i] not in train_timesteps:
|
389
|
+
non_train_timesteps.append(timesteps[i])
|
390
|
+
|
391
|
+
if timesteps[0] >= self.config.num_train_timesteps:
|
392
|
+
raise ValueError(
|
393
|
+
f"`timesteps` must start before `self.config.train_timesteps`:"
|
394
|
+
f" {self.config.num_train_timesteps}."
|
395
|
+
)
|
396
|
+
|
397
|
+
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
|
398
|
+
if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
|
399
|
+
logger.warning(
|
400
|
+
f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
|
401
|
+
f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
|
402
|
+
f" unexpected results when using this timestep schedule."
|
403
|
+
)
|
404
|
+
|
405
|
+
# Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
|
406
|
+
if non_train_timesteps:
|
407
|
+
logger.warning(
|
408
|
+
f"The custom timestep schedule contains the following timesteps which are not on the original"
|
409
|
+
f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
|
410
|
+
f" when using this timestep schedule."
|
411
|
+
)
|
412
|
+
|
413
|
+
# Raise warning if custom timestep schedule is longer than original_steps
|
414
|
+
if len(timesteps) > original_steps:
|
415
|
+
logger.warning(
|
416
|
+
f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
|
417
|
+
f" the length of the timestep schedule used for training: {original_steps}. You may get some"
|
418
|
+
f" unexpected results when using this timestep schedule."
|
419
|
+
)
|
420
|
+
|
421
|
+
timesteps = np.array(timesteps, dtype=np.int64)
|
422
|
+
self.num_inference_steps = len(timesteps)
|
423
|
+
self.custom_timesteps = True
|
424
|
+
|
425
|
+
# Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
|
426
|
+
init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
|
427
|
+
t_start = max(self.num_inference_steps - init_timestep, 0)
|
428
|
+
timesteps = timesteps[t_start * self.order :]
|
429
|
+
# TODO: also reset self.num_inference_steps?
|
430
|
+
else:
|
431
|
+
# 2.2 Create the "standard" LCM inference timestep schedule.
|
432
|
+
if num_inference_steps > self.config.num_train_timesteps:
|
433
|
+
raise ValueError(
|
434
|
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
435
|
+
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
436
|
+
f" maximal {self.config.num_train_timesteps} timesteps."
|
437
|
+
)
|
438
|
+
|
439
|
+
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
|
440
|
+
|
441
|
+
if skipping_step < 1:
|
442
|
+
raise ValueError(
|
443
|
+
f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
|
444
|
+
)
|
445
|
+
|
446
|
+
self.num_inference_steps = num_inference_steps
|
447
|
+
|
448
|
+
if num_inference_steps > original_steps:
|
449
|
+
raise ValueError(
|
450
|
+
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
|
451
|
+
f" {original_steps} because the final timestep schedule will be a subset of the"
|
452
|
+
f" `original_inference_steps`-sized initial timestep schedule."
|
453
|
+
)
|
454
|
+
|
455
|
+
# LCM Inference Steps Schedule
|
456
|
+
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
|
457
|
+
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
|
458
|
+
inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
|
459
|
+
inference_indices = np.floor(inference_indices).astype(np.int64)
|
460
|
+
timesteps = lcm_origin_timesteps[inference_indices]
|
461
|
+
|
462
|
+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
|
385
463
|
|
386
464
|
self._step_index = None
|
387
465
|
|
@@ -536,3 +614,19 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
|
|
536
614
|
|
537
615
|
def __len__(self):
|
538
616
|
return self.config.num_train_timesteps
|
617
|
+
|
618
|
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
619
|
+
def previous_timestep(self, timestep):
|
620
|
+
if self.custom_timesteps:
|
621
|
+
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
622
|
+
if index == self.timesteps.shape[0] - 1:
|
623
|
+
prev_t = torch.tensor(-1)
|
624
|
+
else:
|
625
|
+
prev_t = self.timesteps[index + 1]
|
626
|
+
else:
|
627
|
+
num_inference_steps = (
|
628
|
+
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
629
|
+
)
|
630
|
+
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
|
631
|
+
|
632
|
+
return prev_t
|
@@ -146,9 +146,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
146
146
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
147
147
|
elif beta_schedule == "scaled_linear":
|
148
148
|
# this schedule is very specific to the latent diffusion model.
|
149
|
-
self.betas = (
|
150
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
151
|
-
)
|
149
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
152
150
|
elif beta_schedule == "squaredcos_cap_v2":
|
153
151
|
# Glide cosine schedule
|
154
152
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -170,6 +168,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
170
168
|
self.is_scale_input_called = False
|
171
169
|
|
172
170
|
self._step_index = None
|
171
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
173
172
|
|
174
173
|
@property
|
175
174
|
def init_noise_sigma(self):
|
@@ -281,6 +280,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
281
280
|
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
282
281
|
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
283
282
|
self._step_index = None
|
283
|
+
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
284
284
|
|
285
285
|
self.derivatives = []
|
286
286
|
|
@@ -132,9 +132,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|
132
132
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
133
133
|
elif beta_schedule == "scaled_linear":
|
134
134
|
# this schedule is very specific to the latent diffusion model.
|
135
|
-
self.betas = (
|
136
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
137
|
-
)
|
135
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
138
136
|
elif beta_schedule == "squaredcos_cap_v2":
|
139
137
|
# Glide cosine schedule
|
140
138
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
@@ -134,9 +134,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
|
134
134
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
135
135
|
elif beta_schedule == "scaled_linear":
|
136
136
|
# this schedule is very specific to the latent diffusion model.
|
137
|
-
self.betas = (
|
138
|
-
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
139
|
-
)
|
137
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
140
138
|
elif beta_schedule == "squaredcos_cap_v2":
|
141
139
|
# Glide cosine schedule
|
142
140
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|