diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -110,7 +110,7 @@ def betas_for_alpha_bar(
|
|
110
110
|
return math.exp(t * -12.0)
|
111
111
|
|
112
112
|
else:
|
113
|
-
raise ValueError(f"Unsupported
|
113
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
114
114
|
|
115
115
|
betas = []
|
116
116
|
for i in range(num_diffusion_timesteps):
|
@@ -184,7 +184,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
184
184
|
# Glide cosine schedule
|
185
185
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
186
186
|
else:
|
187
|
-
raise NotImplementedError(f"{beta_schedule}
|
187
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
188
188
|
|
189
189
|
self.alphas = 1.0 - self.betas
|
190
190
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
@@ -233,7 +233,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
233
233
|
@property
|
234
234
|
def step_index(self):
|
235
235
|
"""
|
236
|
-
The index counter for current timestep. It will
|
236
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
237
237
|
"""
|
238
238
|
return self._step_index
|
239
239
|
|
@@ -257,21 +257,21 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
257
257
|
|
258
258
|
def scale_model_input(
|
259
259
|
self,
|
260
|
-
sample: torch.
|
261
|
-
timestep: Union[float, torch.
|
262
|
-
) -> torch.
|
260
|
+
sample: torch.Tensor,
|
261
|
+
timestep: Union[float, torch.Tensor],
|
262
|
+
) -> torch.Tensor:
|
263
263
|
"""
|
264
264
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
265
265
|
current timestep.
|
266
266
|
|
267
267
|
Args:
|
268
|
-
sample (`torch.
|
268
|
+
sample (`torch.Tensor`):
|
269
269
|
The input sample.
|
270
270
|
timestep (`int`, *optional*):
|
271
271
|
The current timestep in the diffusion chain.
|
272
272
|
|
273
273
|
Returns:
|
274
|
-
`torch.
|
274
|
+
`torch.Tensor`:
|
275
275
|
A scaled input sample.
|
276
276
|
"""
|
277
277
|
if self.step_index is None:
|
@@ -325,7 +325,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
325
325
|
log_sigmas = np.log(sigmas)
|
326
326
|
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
327
327
|
|
328
|
-
if self.use_karras_sigmas:
|
328
|
+
if self.config.use_karras_sigmas:
|
329
329
|
sigmas = self._convert_to_karras(in_sigmas=sigmas)
|
330
330
|
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
331
331
|
|
@@ -395,7 +395,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
395
395
|
return t
|
396
396
|
|
397
397
|
# copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
|
398
|
-
def _convert_to_karras(self, in_sigmas: torch.
|
398
|
+
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
|
399
399
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
400
400
|
|
401
401
|
sigma_min: float = in_sigmas[-1].item()
|
@@ -414,9 +414,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
414
414
|
|
415
415
|
def step(
|
416
416
|
self,
|
417
|
-
model_output: Union[torch.
|
418
|
-
timestep: Union[float, torch.
|
419
|
-
sample: Union[torch.
|
417
|
+
model_output: Union[torch.Tensor, np.ndarray],
|
418
|
+
timestep: Union[float, torch.Tensor],
|
419
|
+
sample: Union[torch.Tensor, np.ndarray],
|
420
420
|
return_dict: bool = True,
|
421
421
|
s_noise: float = 1.0,
|
422
422
|
) -> Union[SchedulerOutput, Tuple]:
|
@@ -425,11 +425,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
425
425
|
process from the learned model outputs (most often the predicted noise).
|
426
426
|
|
427
427
|
Args:
|
428
|
-
model_output (`torch.
|
428
|
+
model_output (`torch.Tensor` or `np.ndarray`):
|
429
429
|
The direct output from learned diffusion model.
|
430
|
-
timestep (`float` or `torch.
|
430
|
+
timestep (`float` or `torch.Tensor`):
|
431
431
|
The current discrete timestep in the diffusion chain.
|
432
|
-
sample (`torch.
|
432
|
+
sample (`torch.Tensor` or `np.ndarray`):
|
433
433
|
A current instance of a sample created by the diffusion process.
|
434
434
|
return_dict (`bool`, *optional*, defaults to `True`):
|
435
435
|
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
|
@@ -450,10 +450,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
450
450
|
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
|
451
451
|
|
452
452
|
# Define functions to compute sigma and t from each other
|
453
|
-
def sigma_fn(_t: torch.
|
453
|
+
def sigma_fn(_t: torch.Tensor) -> torch.Tensor:
|
454
454
|
return _t.neg().exp()
|
455
455
|
|
456
|
-
def t_fn(_sigma: torch.
|
456
|
+
def t_fn(_sigma: torch.Tensor) -> torch.Tensor:
|
457
457
|
return _sigma.log().neg()
|
458
458
|
|
459
459
|
if self.state_in_first_order:
|
@@ -526,10 +526,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
526
526
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
|
527
527
|
def add_noise(
|
528
528
|
self,
|
529
|
-
original_samples: torch.
|
530
|
-
noise: torch.
|
531
|
-
timesteps: torch.
|
532
|
-
) -> torch.
|
529
|
+
original_samples: torch.Tensor,
|
530
|
+
noise: torch.Tensor,
|
531
|
+
timesteps: torch.Tensor,
|
532
|
+
) -> torch.Tensor:
|
533
533
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
534
534
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
535
535
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -543,7 +543,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
|
543
543
|
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
544
544
|
if self.begin_index is None:
|
545
545
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
546
|
+
elif self.step_index is not None:
|
547
|
+
# add_noise is called after first denoising step (for inpainting)
|
548
|
+
step_indices = [self.step_index] * timesteps.shape[0]
|
546
549
|
else:
|
550
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
547
551
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
548
552
|
|
549
553
|
sigma = sigmas[step_indices].flatten()
|
@@ -63,7 +63,7 @@ def betas_for_alpha_bar(
|
|
63
63
|
return math.exp(t * -12.0)
|
64
64
|
|
65
65
|
else:
|
66
|
-
raise ValueError(f"Unsupported
|
66
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
67
67
|
|
68
68
|
betas = []
|
69
69
|
for i in range(num_diffusion_timesteps):
|
@@ -108,11 +108,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
108
108
|
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
109
109
|
`algorithm_type="dpmsolver++"`.
|
110
110
|
algorithm_type (`str`, defaults to `dpmsolver++`):
|
111
|
-
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
111
|
+
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
|
112
|
+
algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
|
113
|
+
implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
|
114
|
+
recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
|
115
|
+
Stable Diffusion.
|
116
116
|
solver_type (`str`, defaults to `midpoint`):
|
117
117
|
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
118
118
|
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
@@ -123,8 +123,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
123
123
|
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
124
124
|
the sigmas are determined according to a sequence of noise levels {σi}.
|
125
125
|
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
|
126
|
-
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
127
|
-
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
126
|
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
127
|
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
128
128
|
lambda_min_clipped (`float`, defaults to `-inf`):
|
129
129
|
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
130
130
|
cosine (`squaredcos_cap_v2`) noise schedule.
|
@@ -172,7 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
172
172
|
# Glide cosine schedule
|
173
173
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
174
174
|
else:
|
175
|
-
raise NotImplementedError(f"{beta_schedule}
|
175
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
176
176
|
|
177
177
|
self.alphas = 1.0 - self.betas
|
178
178
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
@@ -190,12 +190,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
190
190
|
if algorithm_type == "deis":
|
191
191
|
self.register_to_config(algorithm_type="dpmsolver++")
|
192
192
|
else:
|
193
|
-
raise NotImplementedError(f"{algorithm_type}
|
193
|
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
194
194
|
if solver_type not in ["midpoint", "heun"]:
|
195
195
|
if solver_type in ["logrho", "bh1", "bh2"]:
|
196
196
|
self.register_to_config(solver_type="midpoint")
|
197
197
|
else:
|
198
|
-
raise NotImplementedError(f"{solver_type}
|
198
|
+
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
|
199
199
|
|
200
200
|
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
|
201
201
|
raise ValueError(
|
@@ -252,7 +252,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
252
252
|
@property
|
253
253
|
def step_index(self):
|
254
254
|
"""
|
255
|
-
The index counter for current timestep. It will
|
255
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
256
256
|
"""
|
257
257
|
return self._step_index
|
258
258
|
|
@@ -274,7 +274,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
274
274
|
"""
|
275
275
|
self._begin_index = begin_index
|
276
276
|
|
277
|
-
def set_timesteps(
|
277
|
+
def set_timesteps(
|
278
|
+
self,
|
279
|
+
num_inference_steps: int = None,
|
280
|
+
device: Union[str, torch.device] = None,
|
281
|
+
timesteps: Optional[List[int]] = None,
|
282
|
+
):
|
278
283
|
"""
|
279
284
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
280
285
|
|
@@ -283,17 +288,33 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
283
288
|
The number of diffusion steps used when generating samples with a pre-trained model.
|
284
289
|
device (`str` or `torch.device`, *optional*):
|
285
290
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
291
|
+
timesteps (`List[int]`, *optional*):
|
292
|
+
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
293
|
+
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
|
294
|
+
passed, `num_inference_steps` must be `None`.
|
286
295
|
"""
|
296
|
+
if num_inference_steps is None and timesteps is None:
|
297
|
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
298
|
+
if num_inference_steps is not None and timesteps is not None:
|
299
|
+
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
|
300
|
+
if timesteps is not None and self.config.use_karras_sigmas:
|
301
|
+
raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
|
302
|
+
|
303
|
+
num_inference_steps = num_inference_steps or len(timesteps)
|
287
304
|
self.num_inference_steps = num_inference_steps
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
.
|
295
|
-
|
296
|
-
|
305
|
+
|
306
|
+
if timesteps is not None:
|
307
|
+
timesteps = np.array(timesteps).astype(np.int64)
|
308
|
+
else:
|
309
|
+
# Clipping the minimum of all lambda(t) for numerical stability.
|
310
|
+
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
|
311
|
+
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
|
312
|
+
timesteps = (
|
313
|
+
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
|
314
|
+
.round()[::-1][:-1]
|
315
|
+
.copy()
|
316
|
+
.astype(np.int64)
|
317
|
+
)
|
297
318
|
|
298
319
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
299
320
|
if self.config.use_karras_sigmas:
|
@@ -340,7 +361,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
340
361
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
341
362
|
|
342
363
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
343
|
-
def _threshold_sample(self, sample: torch.
|
364
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
344
365
|
"""
|
345
366
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
346
367
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -405,7 +426,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
405
426
|
return alpha_t, sigma_t
|
406
427
|
|
407
428
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
408
|
-
def _convert_to_karras(self, in_sigmas: torch.
|
429
|
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
409
430
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
410
431
|
|
411
432
|
# Hack to make sure that other schedulers which copy this function don't break
|
@@ -432,11 +453,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
432
453
|
|
433
454
|
def convert_model_output(
|
434
455
|
self,
|
435
|
-
model_output: torch.
|
456
|
+
model_output: torch.Tensor,
|
436
457
|
*args,
|
437
|
-
sample: torch.
|
458
|
+
sample: torch.Tensor = None,
|
438
459
|
**kwargs,
|
439
|
-
) -> torch.
|
460
|
+
) -> torch.Tensor:
|
440
461
|
"""
|
441
462
|
Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is
|
442
463
|
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
|
@@ -450,13 +471,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
450
471
|
</Tip>
|
451
472
|
|
452
473
|
Args:
|
453
|
-
model_output (`torch.
|
474
|
+
model_output (`torch.Tensor`):
|
454
475
|
The direct output from the learned diffusion model.
|
455
|
-
sample (`torch.
|
476
|
+
sample (`torch.Tensor`):
|
456
477
|
A current instance of a sample created by the diffusion process.
|
457
478
|
|
458
479
|
Returns:
|
459
|
-
`torch.
|
480
|
+
`torch.Tensor`:
|
460
481
|
The converted model output.
|
461
482
|
"""
|
462
483
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -521,26 +542,26 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
521
542
|
|
522
543
|
def dpm_solver_first_order_update(
|
523
544
|
self,
|
524
|
-
model_output: torch.
|
545
|
+
model_output: torch.Tensor,
|
525
546
|
*args,
|
526
|
-
sample: torch.
|
547
|
+
sample: torch.Tensor = None,
|
527
548
|
**kwargs,
|
528
|
-
) -> torch.
|
549
|
+
) -> torch.Tensor:
|
529
550
|
"""
|
530
551
|
One step for the first-order DPMSolver (equivalent to DDIM).
|
531
552
|
|
532
553
|
Args:
|
533
|
-
model_output (`torch.
|
554
|
+
model_output (`torch.Tensor`):
|
534
555
|
The direct output from the learned diffusion model.
|
535
556
|
timestep (`int`):
|
536
557
|
The current discrete timestep in the diffusion chain.
|
537
558
|
prev_timestep (`int`):
|
538
559
|
The previous discrete timestep in the diffusion chain.
|
539
|
-
sample (`torch.
|
560
|
+
sample (`torch.Tensor`):
|
540
561
|
A current instance of a sample created by the diffusion process.
|
541
562
|
|
542
563
|
Returns:
|
543
|
-
`torch.
|
564
|
+
`torch.Tensor`:
|
544
565
|
The sample tensor at the previous timestep.
|
545
566
|
"""
|
546
567
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -577,27 +598,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
577
598
|
|
578
599
|
def singlestep_dpm_solver_second_order_update(
|
579
600
|
self,
|
580
|
-
model_output_list: List[torch.
|
601
|
+
model_output_list: List[torch.Tensor],
|
581
602
|
*args,
|
582
|
-
sample: torch.
|
603
|
+
sample: torch.Tensor = None,
|
583
604
|
**kwargs,
|
584
|
-
) -> torch.
|
605
|
+
) -> torch.Tensor:
|
585
606
|
"""
|
586
607
|
One step for the second-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
587
608
|
time `timestep_list[-2]`.
|
588
609
|
|
589
610
|
Args:
|
590
|
-
model_output_list (`List[torch.
|
611
|
+
model_output_list (`List[torch.Tensor]`):
|
591
612
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
592
613
|
timestep (`int`):
|
593
614
|
The current and latter discrete timestep in the diffusion chain.
|
594
615
|
prev_timestep (`int`):
|
595
616
|
The previous discrete timestep in the diffusion chain.
|
596
|
-
sample (`torch.
|
617
|
+
sample (`torch.Tensor`):
|
597
618
|
A current instance of a sample created by the diffusion process.
|
598
619
|
|
599
620
|
Returns:
|
600
|
-
`torch.
|
621
|
+
`torch.Tensor`:
|
601
622
|
The sample tensor at the previous timestep.
|
602
623
|
"""
|
603
624
|
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
@@ -671,27 +692,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
671
692
|
|
672
693
|
def singlestep_dpm_solver_third_order_update(
|
673
694
|
self,
|
674
|
-
model_output_list: List[torch.
|
695
|
+
model_output_list: List[torch.Tensor],
|
675
696
|
*args,
|
676
|
-
sample: torch.
|
697
|
+
sample: torch.Tensor = None,
|
677
698
|
**kwargs,
|
678
|
-
) -> torch.
|
699
|
+
) -> torch.Tensor:
|
679
700
|
"""
|
680
701
|
One step for the third-order singlestep DPMSolver that computes the solution at time `prev_timestep` from the
|
681
702
|
time `timestep_list[-3]`.
|
682
703
|
|
683
704
|
Args:
|
684
|
-
model_output_list (`List[torch.
|
705
|
+
model_output_list (`List[torch.Tensor]`):
|
685
706
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
686
707
|
timestep (`int`):
|
687
708
|
The current and latter discrete timestep in the diffusion chain.
|
688
709
|
prev_timestep (`int`):
|
689
710
|
The previous discrete timestep in the diffusion chain.
|
690
|
-
sample (`torch.
|
711
|
+
sample (`torch.Tensor`):
|
691
712
|
A current instance of a sample created by diffusion process.
|
692
713
|
|
693
714
|
Returns:
|
694
|
-
`torch.
|
715
|
+
`torch.Tensor`:
|
695
716
|
The sample tensor at the previous timestep.
|
696
717
|
"""
|
697
718
|
|
@@ -775,29 +796,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
775
796
|
|
776
797
|
def singlestep_dpm_solver_update(
|
777
798
|
self,
|
778
|
-
model_output_list: List[torch.
|
799
|
+
model_output_list: List[torch.Tensor],
|
779
800
|
*args,
|
780
|
-
sample: torch.
|
801
|
+
sample: torch.Tensor = None,
|
781
802
|
order: int = None,
|
782
803
|
**kwargs,
|
783
|
-
) -> torch.
|
804
|
+
) -> torch.Tensor:
|
784
805
|
"""
|
785
806
|
One step for the singlestep DPMSolver.
|
786
807
|
|
787
808
|
Args:
|
788
|
-
model_output_list (`List[torch.
|
809
|
+
model_output_list (`List[torch.Tensor]`):
|
789
810
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
790
811
|
timestep (`int`):
|
791
812
|
The current and latter discrete timestep in the diffusion chain.
|
792
813
|
prev_timestep (`int`):
|
793
814
|
The previous discrete timestep in the diffusion chain.
|
794
|
-
sample (`torch.
|
815
|
+
sample (`torch.Tensor`):
|
795
816
|
A current instance of a sample created by diffusion process.
|
796
817
|
order (`int`):
|
797
818
|
The solver order at this step.
|
798
819
|
|
799
820
|
Returns:
|
800
|
-
`torch.
|
821
|
+
`torch.Tensor`:
|
801
822
|
The sample tensor at the previous timestep.
|
802
823
|
"""
|
803
824
|
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
@@ -870,9 +891,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
870
891
|
|
871
892
|
def step(
|
872
893
|
self,
|
873
|
-
model_output: torch.
|
894
|
+
model_output: torch.Tensor,
|
874
895
|
timestep: int,
|
875
|
-
sample: torch.
|
896
|
+
sample: torch.Tensor,
|
876
897
|
return_dict: bool = True,
|
877
898
|
) -> Union[SchedulerOutput, Tuple]:
|
878
899
|
"""
|
@@ -880,11 +901,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
880
901
|
the singlestep DPMSolver.
|
881
902
|
|
882
903
|
Args:
|
883
|
-
model_output (`torch.
|
904
|
+
model_output (`torch.Tensor`):
|
884
905
|
The direct output from learned diffusion model.
|
885
906
|
timestep (`int`):
|
886
907
|
The current discrete timestep in the diffusion chain.
|
887
|
-
sample (`torch.
|
908
|
+
sample (`torch.Tensor`):
|
888
909
|
A current instance of a sample created by the diffusion process.
|
889
910
|
return_dict (`bool`):
|
890
911
|
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
@@ -929,17 +950,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
929
950
|
|
930
951
|
return SchedulerOutput(prev_sample=prev_sample)
|
931
952
|
|
932
|
-
def scale_model_input(self, sample: torch.
|
953
|
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
933
954
|
"""
|
934
955
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
935
956
|
current timestep.
|
936
957
|
|
937
958
|
Args:
|
938
|
-
sample (`torch.
|
959
|
+
sample (`torch.Tensor`):
|
939
960
|
The input sample.
|
940
961
|
|
941
962
|
Returns:
|
942
|
-
`torch.
|
963
|
+
`torch.Tensor`:
|
943
964
|
A scaled input sample.
|
944
965
|
"""
|
945
966
|
return sample
|
@@ -947,10 +968,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
947
968
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
948
969
|
def add_noise(
|
949
970
|
self,
|
950
|
-
original_samples: torch.
|
951
|
-
noise: torch.
|
971
|
+
original_samples: torch.Tensor,
|
972
|
+
noise: torch.Tensor,
|
952
973
|
timesteps: torch.IntTensor,
|
953
|
-
) -> torch.
|
974
|
+
) -> torch.Tensor:
|
954
975
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
955
976
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
956
977
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -961,10 +982,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|
961
982
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
962
983
|
timesteps = timesteps.to(original_samples.device)
|
963
984
|
|
964
|
-
# begin_index is None when the scheduler is used for training
|
985
|
+
# begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index
|
965
986
|
if self.begin_index is None:
|
966
987
|
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
|
988
|
+
elif self.step_index is not None:
|
989
|
+
# add_noise is called after first denoising step (for inpainting)
|
990
|
+
step_indices = [self.step_index] * timesteps.shape[0]
|
967
991
|
else:
|
992
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
968
993
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
969
994
|
|
970
995
|
sigma = sigmas[step_indices].flatten()
|