diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
|
|
61
61
|
return math.exp(t * -12.0)
|
62
62
|
|
63
63
|
else:
|
64
|
-
raise ValueError(f"Unsupported
|
64
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
65
65
|
|
66
66
|
betas = []
|
67
67
|
for i in range(num_diffusion_timesteps):
|
@@ -78,11 +78,11 @@ def rescale_zero_terminal_snr(betas):
|
|
78
78
|
|
79
79
|
|
80
80
|
Args:
|
81
|
-
betas (`torch.
|
81
|
+
betas (`torch.Tensor`):
|
82
82
|
the betas that the scheduler is being initialized with.
|
83
83
|
|
84
84
|
Returns:
|
85
|
-
`torch.
|
85
|
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
86
86
|
"""
|
87
87
|
# Convert betas to alphas_bar_sqrt
|
88
88
|
alphas = 1.0 - betas
|
@@ -166,8 +166,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
166
166
|
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
167
167
|
`lambda(t)`.
|
168
168
|
final_sigmas_type (`str`, defaults to `"zero"`):
|
169
|
-
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
170
|
-
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
169
|
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
170
|
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
171
171
|
lambda_min_clipped (`float`, defaults to `-inf`):
|
172
172
|
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
173
173
|
cosine (`squaredcos_cap_v2`) noise schedule.
|
@@ -229,7 +229,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
229
229
|
# Glide cosine schedule
|
230
230
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
231
231
|
else:
|
232
|
-
raise NotImplementedError(f"{beta_schedule}
|
232
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
233
233
|
|
234
234
|
if rescale_betas_zero_snr:
|
235
235
|
self.betas = rescale_zero_terminal_snr(self.betas)
|
@@ -256,13 +256,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
256
256
|
if algorithm_type == "deis":
|
257
257
|
self.register_to_config(algorithm_type="dpmsolver++")
|
258
258
|
else:
|
259
|
-
raise NotImplementedError(f"{algorithm_type}
|
259
|
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
260
260
|
|
261
261
|
if solver_type not in ["midpoint", "heun"]:
|
262
262
|
if solver_type in ["logrho", "bh1", "bh2"]:
|
263
263
|
self.register_to_config(solver_type="midpoint")
|
264
264
|
else:
|
265
|
-
raise NotImplementedError(f"{solver_type}
|
265
|
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
266
266
|
|
267
267
|
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
268
268
|
raise ValueError(
|
@@ -282,7 +282,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
282
282
|
@property
|
283
283
|
def step_index(self):
|
284
284
|
"""
|
285
|
-
The index counter for current timestep. It will
|
285
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
286
286
|
"""
|
287
287
|
return self._step_index
|
288
288
|
|
@@ -303,7 +303,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
303
303
|
"""
|
304
304
|
self._begin_index = begin_index
|
305
305
|
|
306
|
-
def set_timesteps(
|
306
|
+
def set_timesteps(
|
307
|
+
self,
|
308
|
+
num_inference_steps: int = None,
|
309
|
+
device: Union[str, torch.device] = None,
|
310
|
+
timesteps: Optional[List[int]] = None,
|
311
|
+
):
|
307
312
|
"""
|
308
313
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
309
314
|
|
@@ -312,33 +317,54 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
312
317
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
313
318
|
device (`str` or `torch.device`, *optional*):
|
314
319
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
320
|
+
timesteps (`List[int]`, *optional*):
|
321
|
+
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
|
322
|
+
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
|
323
|
+
must be `None`, and `timestep_spacing` attribute will be ignored.
|
315
324
|
"""
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
if self.config.
|
323
|
-
timesteps =
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
step_ratio = last_timestep // (num_inference_steps + 1)
|
328
|
-
# creates integer timesteps by multiplying by ratio
|
329
|
-
# casting to int to avoid issues when num_inference_step is power of 3
|
330
|
-
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
|
331
|
-
timesteps += self.config.steps_offset
|
332
|
-
elif self.config.timestep_spacing == "trailing":
|
333
|
-
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
334
|
-
# creates integer timesteps by multiplying by ratio
|
335
|
-
# casting to int to avoid issues when num_inference_step is power of 3
|
336
|
-
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
|
337
|
-
timesteps -= 1
|
325
|
+
if num_inference_steps is None and timesteps is None:
|
326
|
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
327
|
+
if num_inference_steps is not None and timesteps is not None:
|
328
|
+
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
329
|
+
if timesteps is not None and self.config.use_karras_sigmas:
|
330
|
+
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
|
331
|
+
if timesteps is not None and self.config.use_lu_lambdas:
|
332
|
+
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
|
333
|
+
|
334
|
+
if timesteps is not None:
|
335
|
+
timesteps = np.array(timesteps).astype(np.int64)
|
338
336
|
else:
|
339
|
-
|
340
|
-
|
341
|
-
)
|
337
|
+
# Clipping the minimum of all lambda(t) for numerical stability.
|
338
|
+
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
339
|
+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
|
340
|
+
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
|
341
|
+
|
342
|
+
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
343
|
+
if self.config.timestep_spacing == "linspace":
|
344
|
+
timesteps = (
|
345
|
+
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
|
346
|
+
.round()[::-1][:-1]
|
347
|
+
.copy()
|
348
|
+
.astype(np.int64)
|
349
|
+
)
|
350
|
+
elif self.config.timestep_spacing == "leading":
|
351
|
+
step_ratio = last_timestep // (num_inference_steps + 1)
|
352
|
+
# creates integer timesteps by multiplying by ratio
|
353
|
+
# casting to int to avoid issues when num_inference_step is power of 3
|
354
|
+
timesteps = (
|
355
|
+
(np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
|
356
|
+
)
|
357
|
+
timesteps += self.config.steps_offset
|
358
|
+
elif self.config.timestep_spacing == "trailing":
|
359
|
+
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
360
|
+
# creates integer timesteps by multiplying by ratio
|
361
|
+
# casting to int to avoid issues when num_inference_step is power of 3
|
362
|
+
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
|
363
|
+
timesteps -= 1
|
364
|
+
else:
|
365
|
+
raise ValueError(
|
366
|
+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
367
|
+
)
|
342
368
|
|
343
369
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
344
370
|
log_sigmas = np.log(sigmas)
|
@@ -382,7 +408,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
382
408
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
383
409
|
|
384
410
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
385
|
-
def _threshold_sample(self, sample: torch.
|
411
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
386
412
|
"""
|
387
413
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
388
414
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -446,7 +472,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
446
472
|
return alpha_t, sigma_t
|
447
473
|
|
448
474
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
449
|
-
def _convert_to_karras(self, in_sigmas: torch.
|
475
|
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
450
476
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
451
477
|
|
452
478
|
# Hack to make sure that other schedulers which copy this function don't break
|
@@ -471,7 +497,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
471
497
|
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
472
498
|
return sigmas
|
473
499
|
|
474
|
-
def _convert_to_lu(self, in_lambdas: torch.
|
500
|
+
def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
475
501
|
"""Constructs the noise schedule of Lu et al. (2022)."""
|
476
502
|
|
477
503
|
lambda_min: float = in_lambdas[-1].item()
|
@@ -486,11 +512,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
486
512
|
|
487
513
|
def convert_model_output(
|
488
514
|
self,
|
489
|
-
model_output: torch.
|
515
|
+
model_output: torch.Tensor,
|
490
516
|
*args,
|
491
|
-
sample: torch.
|
517
|
+
sample: torch.Tensor = None,
|
492
518
|
**kwargs,
|
493
|
-
) -> torch.
|
519
|
+
) -> torch.Tensor:
|
494
520
|
"""
|
495
521
|
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
496
522
|
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
@@ -504,13 +530,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
504
530
|
</Tip>
|
505
531
|
|
506
532
|
Args:
|
507
|
-
model_output (`torch.
|
533
|
+
model_output (`torch.Tensor`):
|
508
534
|
The direct output from the learned diffusion model.
|
509
|
-
sample (`torch.
|
535
|
+
sample (`torch.Tensor`):
|
510
536
|
A current instance of a sample created by the diffusion process.
|
511
537
|
|
512
538
|
Returns:
|
513
|
-
`torch.
|
539
|
+
`torch.Tensor`:
|
514
540
|
The converted model output.
|
515
541
|
"""
|
516
542
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -585,23 +611,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
585
611
|
|
586
612
|
def dpm_solver_first_order_update(
|
587
613
|
self,
|
588
|
-
model_output: torch.
|
614
|
+
model_output: torch.Tensor,
|
589
615
|
*args,
|
590
|
-
sample: torch.
|
591
|
-
noise: Optional[torch.
|
616
|
+
sample: torch.Tensor = None,
|
617
|
+
noise: Optional[torch.Tensor] = None,
|
592
618
|
**kwargs,
|
593
|
-
) -> torch.
|
619
|
+
) -> torch.Tensor:
|
594
620
|
"""
|
595
621
|
One step for the first-order DPMSolver (equivalent to DDIM).
|
596
622
|
|
597
623
|
Args:
|
598
|
-
model_output (`torch.
|
624
|
+
model_output (`torch.Tensor`):
|
599
625
|
The direct output from the learned diffusion model.
|
600
|
-
sample (`torch.
|
626
|
+
sample (`torch.Tensor`):
|
601
627
|
A current instance of a sample created by the diffusion process.
|
602
628
|
|
603
629
|
Returns:
|
604
|
-
`torch.
|
630
|
+
`torch.Tensor`:
|
605
631
|
The sample tensor at the previous timestep.
|
606
632
|
"""
|
607
633
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -654,23 +680,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
654
680
|
|
655
681
|
def multistep_dpm_solver_second_order_update(
|
656
682
|
self,
|
657
|
-
model_output_list: List[torch.
|
683
|
+
model_output_list: List[torch.Tensor],
|
658
684
|
*args,
|
659
|
-
sample: torch.
|
660
|
-
noise: Optional[torch.
|
685
|
+
sample: torch.Tensor = None,
|
686
|
+
noise: Optional[torch.Tensor] = None,
|
661
687
|
**kwargs,
|
662
|
-
) -> torch.
|
688
|
+
) -> torch.Tensor:
|
663
689
|
"""
|
664
690
|
One step for the second-order multistep DPMSolver.
|
665
691
|
|
666
692
|
Args:
|
667
|
-
model_output_list (`List[torch.
|
693
|
+
model_output_list (`List[torch.Tensor]`):
|
668
694
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
669
|
-
sample (`torch.
|
695
|
+
sample (`torch.Tensor`):
|
670
696
|
A current instance of a sample created by the diffusion process.
|
671
697
|
|
672
698
|
Returns:
|
673
|
-
`torch.
|
699
|
+
`torch.Tensor`:
|
674
700
|
The sample tensor at the previous timestep.
|
675
701
|
"""
|
676
702
|
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
@@ -777,22 +803,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
777
803
|
|
778
804
|
def multistep_dpm_solver_third_order_update(
|
779
805
|
self,
|
780
|
-
model_output_list: List[torch.
|
806
|
+
model_output_list: List[torch.Tensor],
|
781
807
|
*args,
|
782
|
-
sample: torch.
|
808
|
+
sample: torch.Tensor = None,
|
783
809
|
**kwargs,
|
784
|
-
) -> torch.
|
810
|
+
) -> torch.Tensor:
|
785
811
|
"""
|
786
812
|
One step for the third-order multistep DPMSolver.
|
787
813
|
|
788
814
|
Args:
|
789
|
-
model_output_list (`List[torch.
|
815
|
+
model_output_list (`List[torch.Tensor]`):
|
790
816
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
791
|
-
sample (`torch.
|
817
|
+
sample (`torch.Tensor`):
|
792
818
|
A current instance of a sample created by diffusion process.
|
793
819
|
|
794
820
|
Returns:
|
795
|
-
`torch.
|
821
|
+
`torch.Tensor`:
|
796
822
|
The sample tensor at the previous timestep.
|
797
823
|
"""
|
798
824
|
|
@@ -893,11 +919,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
893
919
|
|
894
920
|
def step(
|
895
921
|
self,
|
896
|
-
model_output: torch.
|
922
|
+
model_output: torch.Tensor,
|
897
923
|
timestep: int,
|
898
|
-
sample: torch.
|
924
|
+
sample: torch.Tensor,
|
899
925
|
generator=None,
|
900
|
-
variance_noise: Optional[torch.
|
926
|
+
variance_noise: Optional[torch.Tensor] = None,
|
901
927
|
return_dict: bool = True,
|
902
928
|
) -> Union[SchedulerOutput, Tuple]:
|
903
929
|
"""
|
@@ -905,15 +931,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
905
931
|
the multistep DPMSolver.
|
906
932
|
|
907
933
|
Args:
|
908
|
-
model_output (`torch.
|
934
|
+
model_output (`torch.Tensor`):
|
909
935
|
The direct output from learned diffusion model.
|
910
936
|
timestep (`int`):
|
911
937
|
The current discrete timestep in the diffusion chain.
|
912
|
-
sample (`torch.
|
938
|
+
sample (`torch.Tensor`):
|
913
939
|
A current instance of a sample created by the diffusion process.
|
914
940
|
generator (`torch.Generator`, *optional*):
|
915
941
|
A random number generator.
|
916
|
-
variance_noise (`torch.
|
942
|
+
variance_noise (`torch.Tensor`):
|
917
943
|
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
918
944
|
itself. Useful for methods such as [`LEdits++`].
|
919
945
|
return_dict (`bool`):
|
@@ -980,27 +1006,27 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
980
1006
|
|
981
1007
|
return SchedulerOutput(prev_sample=prev_sample)
|
982
1008
|
|
983
|
-
def scale_model_input(self, sample: torch.
|
1009
|
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
984
1010
|
"""
|
985
1011
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
986
1012
|
current timestep.
|
987
1013
|
|
988
1014
|
Args:
|
989
|
-
sample (`torch.
|
1015
|
+
sample (`torch.Tensor`):
|
990
1016
|
The input sample.
|
991
1017
|
|
992
1018
|
Returns:
|
993
|
-
`torch.
|
1019
|
+
`torch.Tensor`:
|
994
1020
|
A scaled input sample.
|
995
1021
|
"""
|
996
1022
|
return sample
|
997
1023
|
|
998
1024
|
def add_noise(
|
999
1025
|
self,
|
1000
|
-
original_samples: torch.
|
1001
|
-
noise: torch.
|
1026
|
+
original_samples: torch.Tensor,
|
1027
|
+
noise: torch.Tensor,
|
1002
1028
|
timesteps: torch.IntTensor,
|
1003
|
-
) -> torch.
|
1029
|
+
) -> torch.Tensor:
|
1004
1030
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
1005
1031
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
1006
1032
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -1011,10 +1037,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
1011
1037
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
1012
1038
|
timesteps = timesteps.to(original_samples.device)
|
1013
1039
|
|
1014
|
-
# begin_index is None when the scheduler is used for training
|
1040
|
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
1015
1041
|
if self.begin_index is None:
|
1016
1042
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
1043
|
+
elif self.step_index is not None:
|
1044
|
+
# add_noise is called after first denoising step (for inpainting)
|
1045
|
+
step_indices = [self.step_index] * timesteps.shape[0]
|
1017
1046
|
else:
|
1047
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
1018
1048
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
1019
1049
|
|
1020
1050
|
sigma = sigmas[step_indices].flatten()
|
@@ -182,9 +182,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|
182
182
|
|
183
183
|
# settings for DPM-Solver
|
184
184
|
if self.config.algorithm_type not in ["dpmsolver", "dpmsolver++"]:
|
185
|
-
raise NotImplementedError(f"{self.config.algorithm_type}
|
185
|
+
raise NotImplementedError(f"{self.config.algorithm_type} is not implemented for {self.__class__}")
|
186
186
|
if self.config.solver_type not in ["midpoint", "heun"]:
|
187
|
-
raise NotImplementedError(f"{self.config.solver_type}
|
187
|
+
raise NotImplementedError(f"{self.config.solver_type} is not implemented for {self.__class__}")
|
188
188
|
|
189
189
|
# standard deviation of the initial noise distribution
|
190
190
|
init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
|
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
|
|
61
61
|
return math.exp(t * -12.0)
|
62
62
|
|
63
63
|
else:
|
64
|
-
raise ValueError(f"Unsupported
|
64
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
65
65
|
|
66
66
|
betas = []
|
67
67
|
for i in range(num_diffusion_timesteps):
|
@@ -178,7 +178,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
178
178
|
# Glide cosine schedule
|
179
179
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
180
180
|
else:
|
181
|
-
raise NotImplementedError(f"{beta_schedule}
|
181
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
182
182
|
|
183
183
|
self.alphas = 1.0 - self.betas
|
184
184
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
@@ -196,13 +196,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
196
196
|
if algorithm_type == "deis":
|
197
197
|
self.register_to_config(algorithm_type="dpmsolver++")
|
198
198
|
else:
|
199
|
-
raise NotImplementedError(f"{algorithm_type}
|
199
|
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
200
200
|
|
201
201
|
if solver_type not in ["midpoint", "heun"]:
|
202
202
|
if solver_type in ["logrho", "bh1", "bh2"]:
|
203
203
|
self.register_to_config(solver_type="midpoint")
|
204
204
|
else:
|
205
|
-
raise NotImplementedError(f"{solver_type}
|
205
|
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
206
206
|
|
207
207
|
# setable values
|
208
208
|
self.num_inference_steps = None
|
@@ -217,7 +217,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
217
217
|
@property
|
218
218
|
def step_index(self):
|
219
219
|
"""
|
220
|
-
The index counter for current timestep. It will
|
220
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
221
221
|
"""
|
222
222
|
return self._step_index
|
223
223
|
|
@@ -233,7 +233,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
233
233
|
"""
|
234
234
|
# Clipping the minimum of all lambda(t) for numerical stability.
|
235
235
|
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
236
|
-
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped).item()
|
236
|
+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
|
237
237
|
self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx
|
238
238
|
|
239
239
|
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
@@ -295,7 +295,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
295
295
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
296
296
|
|
297
297
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
298
|
-
def _threshold_sample(self, sample: torch.
|
298
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
299
299
|
"""
|
300
300
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
301
301
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -360,7 +360,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
360
360
|
return alpha_t, sigma_t
|
361
361
|
|
362
362
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
363
|
-
def _convert_to_karras(self, in_sigmas: torch.
|
363
|
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
364
364
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
365
365
|
|
366
366
|
# Hack to make sure that other schedulers which copy this function don't break
|
@@ -388,11 +388,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
388
388
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
389
389
|
def convert_model_output(
|
390
390
|
self,
|
391
|
-
model_output: torch.
|
391
|
+
model_output: torch.Tensor,
|
392
392
|
*args,
|
393
|
-
sample: torch.
|
393
|
+
sample: torch.Tensor = None,
|
394
394
|
**kwargs,
|
395
|
-
) -> torch.
|
395
|
+
) -> torch.Tensor:
|
396
396
|
"""
|
397
397
|
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
398
398
|
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
@@ -406,13 +406,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
406
406
|
</Tip>
|
407
407
|
|
408
408
|
Args:
|
409
|
-
model_output (`torch.
|
409
|
+
model_output (`torch.Tensor`):
|
410
410
|
The direct output from the learned diffusion model.
|
411
|
-
sample (`torch.
|
411
|
+
sample (`torch.Tensor`):
|
412
412
|
A current instance of a sample created by the diffusion process.
|
413
413
|
|
414
414
|
Returns:
|
415
|
-
`torch.
|
415
|
+
`torch.Tensor`:
|
416
416
|
The converted model output.
|
417
417
|
"""
|
418
418
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -488,23 +488,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
488
488
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
489
489
|
def dpm_solver_first_order_update(
|
490
490
|
self,
|
491
|
-
model_output: torch.
|
491
|
+
model_output: torch.Tensor,
|
492
492
|
*args,
|
493
|
-
sample: torch.
|
494
|
-
noise: Optional[torch.
|
493
|
+
sample: torch.Tensor = None,
|
494
|
+
noise: Optional[torch.Tensor] = None,
|
495
495
|
**kwargs,
|
496
|
-
) -> torch.
|
496
|
+
) -> torch.Tensor:
|
497
497
|
"""
|
498
498
|
One step for the first-order DPMSolver (equivalent to DDIM).
|
499
499
|
|
500
500
|
Args:
|
501
|
-
model_output (`torch.
|
501
|
+
model_output (`torch.Tensor`):
|
502
502
|
The direct output from the learned diffusion model.
|
503
|
-
sample (`torch.
|
503
|
+
sample (`torch.Tensor`):
|
504
504
|
A current instance of a sample created by the diffusion process.
|
505
505
|
|
506
506
|
Returns:
|
507
|
-
`torch.
|
507
|
+
`torch.Tensor`:
|
508
508
|
The sample tensor at the previous timestep.
|
509
509
|
"""
|
510
510
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -558,23 +558,23 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
558
558
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
559
559
|
def multistep_dpm_solver_second_order_update(
|
560
560
|
self,
|
561
|
-
model_output_list: List[torch.
|
561
|
+
model_output_list: List[torch.Tensor],
|
562
562
|
*args,
|
563
|
-
sample: torch.
|
564
|
-
noise: Optional[torch.
|
563
|
+
sample: torch.Tensor = None,
|
564
|
+
noise: Optional[torch.Tensor] = None,
|
565
565
|
**kwargs,
|
566
|
-
) -> torch.
|
566
|
+
) -> torch.Tensor:
|
567
567
|
"""
|
568
568
|
One step for the second-order multistep DPMSolver.
|
569
569
|
|
570
570
|
Args:
|
571
|
-
model_output_list (`List[torch.
|
571
|
+
model_output_list (`List[torch.Tensor]`):
|
572
572
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
573
|
-
sample (`torch.
|
573
|
+
sample (`torch.Tensor`):
|
574
574
|
A current instance of a sample created by the diffusion process.
|
575
575
|
|
576
576
|
Returns:
|
577
|
-
`torch.
|
577
|
+
`torch.Tensor`:
|
578
578
|
The sample tensor at the previous timestep.
|
579
579
|
"""
|
580
580
|
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
@@ -682,22 +682,22 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
682
682
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
683
683
|
def multistep_dpm_solver_third_order_update(
|
684
684
|
self,
|
685
|
-
model_output_list: List[torch.
|
685
|
+
model_output_list: List[torch.Tensor],
|
686
686
|
*args,
|
687
|
-
sample: torch.
|
687
|
+
sample: torch.Tensor = None,
|
688
688
|
**kwargs,
|
689
|
-
) -> torch.
|
689
|
+
) -> torch.Tensor:
|
690
690
|
"""
|
691
691
|
One step for the third-order multistep DPMSolver.
|
692
692
|
|
693
693
|
Args:
|
694
|
-
model_output_list (`List[torch.
|
694
|
+
model_output_list (`List[torch.Tensor]`):
|
695
695
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
696
|
-
sample (`torch.
|
696
|
+
sample (`torch.Tensor`):
|
697
697
|
A current instance of a sample created by diffusion process.
|
698
698
|
|
699
699
|
Returns:
|
700
|
-
`torch.
|
700
|
+
`torch.Tensor`:
|
701
701
|
The sample tensor at the previous timestep.
|
702
702
|
"""
|
703
703
|
|
@@ -786,11 +786,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
786
786
|
|
787
787
|
def step(
|
788
788
|
self,
|
789
|
-
model_output: torch.
|
789
|
+
model_output: torch.Tensor,
|
790
790
|
timestep: int,
|
791
|
-
sample: torch.
|
791
|
+
sample: torch.Tensor,
|
792
792
|
generator=None,
|
793
|
-
variance_noise: Optional[torch.
|
793
|
+
variance_noise: Optional[torch.Tensor] = None,
|
794
794
|
return_dict: bool = True,
|
795
795
|
) -> Union[SchedulerOutput, Tuple]:
|
796
796
|
"""
|
@@ -798,15 +798,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
798
798
|
the multistep DPMSolver.
|
799
799
|
|
800
800
|
Args:
|
801
|
-
model_output (`torch.
|
801
|
+
model_output (`torch.Tensor`):
|
802
802
|
The direct output from learned diffusion model.
|
803
803
|
timestep (`int`):
|
804
804
|
The current discrete timestep in the diffusion chain.
|
805
|
-
sample (`torch.
|
805
|
+
sample (`torch.Tensor`):
|
806
806
|
A current instance of a sample created by the diffusion process.
|
807
807
|
generator (`torch.Generator`, *optional*):
|
808
808
|
A random number generator.
|
809
|
-
variance_noise (`torch.
|
809
|
+
variance_noise (`torch.Tensor`):
|
810
810
|
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
811
811
|
itself. Useful for methods such as [`CycleDiffusion`].
|
812
812
|
return_dict (`bool`):
|
@@ -867,27 +867,27 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
|
867
867
|
return SchedulerOutput(prev_sample=prev_sample)
|
868
868
|
|
869
869
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
|
870
|
-
def scale_model_input(self, sample: torch.
|
870
|
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
871
871
|
"""
|
872
872
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
873
873
|
current timestep.
|
874
874
|
|
875
875
|
Args:
|
876
|
-
sample (`torch.
|
876
|
+
sample (`torch.Tensor`):
|
877
877
|
The input sample.
|
878
878
|
|
879
879
|
Returns:
|
880
|
-
`torch.
|
880
|
+
`torch.Tensor`:
|
881
881
|
A scaled input sample.
|
882
882
|
"""
|
883
883
|
return sample
|
884
884
|
|
885
885
|
def add_noise(
|
886
886
|
self,
|
887
|
-
original_samples: torch.
|
888
|
-
noise: torch.
|
887
|
+
original_samples: torch.Tensor,
|
888
|
+
noise: torch.Tensor,
|
889
889
|
timesteps: torch.IntTensor,
|
890
|
-
) -> torch.
|
890
|
+
) -> torch.Tensor:
|
891
891
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
892
892
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
893
893
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|