diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -34,16 +34,16 @@ class DDPMParallelSchedulerOutput(BaseOutput):
|
|
34
34
|
Output class for the scheduler's `step` function output.
|
35
35
|
|
36
36
|
Args:
|
37
|
-
prev_sample (`torch.
|
37
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38
38
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
39
39
|
denoising loop.
|
40
|
-
pred_original_sample (`torch.
|
40
|
+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41
41
|
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
42
42
|
`pred_original_sample` can be used to preview progress or for guidance.
|
43
43
|
"""
|
44
44
|
|
45
|
-
prev_sample: torch.
|
46
|
-
pred_original_sample: Optional[torch.
|
45
|
+
prev_sample: torch.Tensor
|
46
|
+
pred_original_sample: Optional[torch.Tensor] = None
|
47
47
|
|
48
48
|
|
49
49
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
@@ -81,7 +81,7 @@ def betas_for_alpha_bar(
|
|
81
81
|
return math.exp(t * -12.0)
|
82
82
|
|
83
83
|
else:
|
84
|
-
raise ValueError(f"Unsupported
|
84
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
85
85
|
|
86
86
|
betas = []
|
87
87
|
for i in range(num_diffusion_timesteps):
|
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
|
|
98
98
|
|
99
99
|
|
100
100
|
Args:
|
101
|
-
betas (`torch.
|
101
|
+
betas (`torch.Tensor`):
|
102
102
|
the betas that the scheduler is being initialized with.
|
103
103
|
|
104
104
|
Returns:
|
105
|
-
`torch.
|
105
|
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
106
106
|
"""
|
107
107
|
# Convert betas to alphas_bar_sqrt
|
108
108
|
alphas = 1.0 - betas
|
@@ -219,7 +219,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
219
219
|
betas = torch.linspace(-6, 6, num_train_timesteps)
|
220
220
|
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
221
221
|
else:
|
222
|
-
raise NotImplementedError(f"{beta_schedule}
|
222
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
223
223
|
|
224
224
|
# Rescale for zero SNR
|
225
225
|
if rescale_betas_zero_snr:
|
@@ -240,19 +240,19 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
240
240
|
self.variance_type = variance_type
|
241
241
|
|
242
242
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input
|
243
|
-
def scale_model_input(self, sample: torch.
|
243
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
244
244
|
"""
|
245
245
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
246
246
|
current timestep.
|
247
247
|
|
248
248
|
Args:
|
249
|
-
sample (`torch.
|
249
|
+
sample (`torch.Tensor`):
|
250
250
|
The input sample.
|
251
251
|
timestep (`int`, *optional*):
|
252
252
|
The current timestep in the diffusion chain.
|
253
253
|
|
254
254
|
Returns:
|
255
|
-
`torch.
|
255
|
+
`torch.Tensor`:
|
256
256
|
A scaled input sample.
|
257
257
|
"""
|
258
258
|
return sample
|
@@ -375,7 +375,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
375
375
|
return variance
|
376
376
|
|
377
377
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
378
|
-
def _threshold_sample(self, sample: torch.
|
378
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
379
379
|
"""
|
380
380
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
381
381
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -410,9 +410,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
410
410
|
|
411
411
|
def step(
|
412
412
|
self,
|
413
|
-
model_output: torch.
|
413
|
+
model_output: torch.Tensor,
|
414
414
|
timestep: int,
|
415
|
-
sample: torch.
|
415
|
+
sample: torch.Tensor,
|
416
416
|
generator=None,
|
417
417
|
return_dict: bool = True,
|
418
418
|
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
|
@@ -421,9 +421,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
421
421
|
process from the learned model outputs (most often the predicted noise).
|
422
422
|
|
423
423
|
Args:
|
424
|
-
model_output (`torch.
|
424
|
+
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
425
425
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
426
|
-
sample (`torch.
|
426
|
+
sample (`torch.Tensor`):
|
427
427
|
current instance of sample being created by diffusion process.
|
428
428
|
generator: random number generator.
|
429
429
|
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
|
@@ -506,10 +506,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
506
506
|
|
507
507
|
def batch_step_no_noise(
|
508
508
|
self,
|
509
|
-
model_output: torch.
|
509
|
+
model_output: torch.Tensor,
|
510
510
|
timesteps: List[int],
|
511
|
-
sample: torch.
|
512
|
-
) -> torch.
|
511
|
+
sample: torch.Tensor,
|
512
|
+
) -> torch.Tensor:
|
513
513
|
"""
|
514
514
|
Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
|
515
515
|
Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
|
@@ -519,14 +519,14 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
519
519
|
process from the learned model outputs (most often the predicted noise).
|
520
520
|
|
521
521
|
Args:
|
522
|
-
model_output (`torch.
|
522
|
+
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
523
523
|
timesteps (`List[int]`):
|
524
524
|
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
525
|
-
sample (`torch.
|
525
|
+
sample (`torch.Tensor`):
|
526
526
|
current instance of sample being created by diffusion process.
|
527
527
|
|
528
528
|
Returns:
|
529
|
-
`torch.
|
529
|
+
`torch.Tensor`: sample tensor at previous timestep.
|
530
530
|
"""
|
531
531
|
t = timesteps
|
532
532
|
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
@@ -587,10 +587,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
587
587
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
588
588
|
def add_noise(
|
589
589
|
self,
|
590
|
-
original_samples: torch.
|
591
|
-
noise: torch.
|
590
|
+
original_samples: torch.Tensor,
|
591
|
+
noise: torch.Tensor,
|
592
592
|
timesteps: torch.IntTensor,
|
593
|
-
) -> torch.
|
593
|
+
) -> torch.Tensor:
|
594
594
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
595
595
|
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
596
596
|
# for the subsequent add_noise calls
|
@@ -612,9 +612,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
|
612
612
|
return noisy_samples
|
613
613
|
|
614
614
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
615
|
-
def get_velocity(
|
616
|
-
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
617
|
-
) -> torch.FloatTensor:
|
615
|
+
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
|
618
616
|
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
619
617
|
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
|
620
618
|
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
|
@@ -33,12 +33,12 @@ class DDPMWuerstchenSchedulerOutput(BaseOutput):
|
|
33
33
|
Output class for the scheduler's step function output.
|
34
34
|
|
35
35
|
Args:
|
36
|
-
prev_sample (`torch.
|
36
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
37
37
|
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
38
38
|
denoising loop.
|
39
39
|
"""
|
40
40
|
|
41
|
-
prev_sample: torch.
|
41
|
+
prev_sample: torch.Tensor
|
42
42
|
|
43
43
|
|
44
44
|
def betas_for_alpha_bar(
|
@@ -75,7 +75,7 @@ def betas_for_alpha_bar(
|
|
75
75
|
return math.exp(t * -12.0)
|
76
76
|
|
77
77
|
else:
|
78
|
-
raise ValueError(f"Unsupported
|
78
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
79
79
|
|
80
80
|
betas = []
|
81
81
|
for i in range(num_diffusion_timesteps):
|
@@ -125,17 +125,17 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
|
|
125
125
|
) ** 2 / self._init_alpha_cumprod.to(device)
|
126
126
|
return alpha_cumprod.clamp(0.0001, 0.9999)
|
127
127
|
|
128
|
-
def scale_model_input(self, sample: torch.
|
128
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
129
129
|
"""
|
130
130
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
131
131
|
current timestep.
|
132
132
|
|
133
133
|
Args:
|
134
|
-
sample (`torch.
|
134
|
+
sample (`torch.Tensor`): input sample
|
135
135
|
timestep (`int`, optional): current timestep
|
136
136
|
|
137
137
|
Returns:
|
138
|
-
`torch.
|
138
|
+
`torch.Tensor`: scaled input sample
|
139
139
|
"""
|
140
140
|
return sample
|
141
141
|
|
@@ -163,9 +163,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
|
|
163
163
|
|
164
164
|
def step(
|
165
165
|
self,
|
166
|
-
model_output: torch.
|
166
|
+
model_output: torch.Tensor,
|
167
167
|
timestep: int,
|
168
|
-
sample: torch.
|
168
|
+
sample: torch.Tensor,
|
169
169
|
generator=None,
|
170
170
|
return_dict: bool = True,
|
171
171
|
) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
|
@@ -174,9 +174,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
|
|
174
174
|
process from the learned model outputs (most often the predicted noise).
|
175
175
|
|
176
176
|
Args:
|
177
|
-
model_output (`torch.
|
177
|
+
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
178
178
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
179
|
-
sample (`torch.
|
179
|
+
sample (`torch.Tensor`):
|
180
180
|
current instance of sample being created by diffusion process.
|
181
181
|
generator: random number generator.
|
182
182
|
return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
|
@@ -209,10 +209,10 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
|
|
209
209
|
|
210
210
|
def add_noise(
|
211
211
|
self,
|
212
|
-
original_samples: torch.
|
213
|
-
noise: torch.
|
214
|
-
timesteps: torch.
|
215
|
-
) -> torch.
|
212
|
+
original_samples: torch.Tensor,
|
213
|
+
noise: torch.Tensor,
|
214
|
+
timesteps: torch.Tensor,
|
215
|
+
) -> torch.Tensor:
|
216
216
|
device = original_samples.device
|
217
217
|
dtype = original_samples.dtype
|
218
218
|
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
|
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
|
|
61
61
|
return math.exp(t * -12.0)
|
62
62
|
|
63
63
|
else:
|
64
|
-
raise ValueError(f"Unsupported
|
64
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
65
65
|
|
66
66
|
betas = []
|
67
67
|
for i in range(num_diffusion_timesteps):
|
@@ -152,7 +152,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
152
152
|
# Glide cosine schedule
|
153
153
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
154
154
|
else:
|
155
|
-
raise NotImplementedError(f"{beta_schedule}
|
155
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
156
156
|
|
157
157
|
self.alphas = 1.0 - self.betas
|
158
158
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
@@ -170,13 +170,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
170
170
|
if algorithm_type in ["dpmsolver", "dpmsolver++"]:
|
171
171
|
self.register_to_config(algorithm_type="deis")
|
172
172
|
else:
|
173
|
-
raise NotImplementedError(f"{algorithm_type}
|
173
|
+
raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
|
174
174
|
|
175
175
|
if solver_type not in ["logrho"]:
|
176
176
|
if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
|
177
177
|
self.register_to_config(solver_type="logrho")
|
178
178
|
else:
|
179
|
-
raise NotImplementedError(f"solver type {solver_type}
|
179
|
+
raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
|
180
180
|
|
181
181
|
# setable values
|
182
182
|
self.num_inference_steps = None
|
@@ -191,7 +191,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
191
191
|
@property
|
192
192
|
def step_index(self):
|
193
193
|
"""
|
194
|
-
The index counter for current timestep. It will
|
194
|
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
195
195
|
"""
|
196
196
|
return self._step_index
|
197
197
|
|
@@ -276,7 +276,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
276
276
|
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
277
277
|
|
278
278
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
279
|
-
def _threshold_sample(self, sample: torch.
|
279
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
280
280
|
"""
|
281
281
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
282
282
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -341,7 +341,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
341
341
|
return alpha_t, sigma_t
|
342
342
|
|
343
343
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
344
|
-
def _convert_to_karras(self, in_sigmas: torch.
|
344
|
+
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
345
345
|
"""Constructs the noise schedule of Karras et al. (2022)."""
|
346
346
|
|
347
347
|
# Hack to make sure that other schedulers which copy this function don't break
|
@@ -368,24 +368,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
368
368
|
|
369
369
|
def convert_model_output(
|
370
370
|
self,
|
371
|
-
model_output: torch.
|
371
|
+
model_output: torch.Tensor,
|
372
372
|
*args,
|
373
|
-
sample: torch.
|
373
|
+
sample: torch.Tensor = None,
|
374
374
|
**kwargs,
|
375
|
-
) -> torch.
|
375
|
+
) -> torch.Tensor:
|
376
376
|
"""
|
377
377
|
Convert the model output to the corresponding type the DEIS algorithm needs.
|
378
378
|
|
379
379
|
Args:
|
380
|
-
model_output (`torch.
|
380
|
+
model_output (`torch.Tensor`):
|
381
381
|
The direct output from the learned diffusion model.
|
382
382
|
timestep (`int`):
|
383
383
|
The current discrete timestep in the diffusion chain.
|
384
|
-
sample (`torch.
|
384
|
+
sample (`torch.Tensor`):
|
385
385
|
A current instance of a sample created by the diffusion process.
|
386
386
|
|
387
387
|
Returns:
|
388
|
-
`torch.
|
388
|
+
`torch.Tensor`:
|
389
389
|
The converted model output.
|
390
390
|
"""
|
391
391
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -425,26 +425,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
425
425
|
|
426
426
|
def deis_first_order_update(
|
427
427
|
self,
|
428
|
-
model_output: torch.
|
428
|
+
model_output: torch.Tensor,
|
429
429
|
*args,
|
430
|
-
sample: torch.
|
430
|
+
sample: torch.Tensor = None,
|
431
431
|
**kwargs,
|
432
|
-
) -> torch.
|
432
|
+
) -> torch.Tensor:
|
433
433
|
"""
|
434
434
|
One step for the first-order DEIS (equivalent to DDIM).
|
435
435
|
|
436
436
|
Args:
|
437
|
-
model_output (`torch.
|
437
|
+
model_output (`torch.Tensor`):
|
438
438
|
The direct output from the learned diffusion model.
|
439
439
|
timestep (`int`):
|
440
440
|
The current discrete timestep in the diffusion chain.
|
441
441
|
prev_timestep (`int`):
|
442
442
|
The previous discrete timestep in the diffusion chain.
|
443
|
-
sample (`torch.
|
443
|
+
sample (`torch.Tensor`):
|
444
444
|
A current instance of a sample created by the diffusion process.
|
445
445
|
|
446
446
|
Returns:
|
447
|
-
`torch.
|
447
|
+
`torch.Tensor`:
|
448
448
|
The sample tensor at the previous timestep.
|
449
449
|
"""
|
450
450
|
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
|
@@ -483,22 +483,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
483
483
|
|
484
484
|
def multistep_deis_second_order_update(
|
485
485
|
self,
|
486
|
-
model_output_list: List[torch.
|
486
|
+
model_output_list: List[torch.Tensor],
|
487
487
|
*args,
|
488
|
-
sample: torch.
|
488
|
+
sample: torch.Tensor = None,
|
489
489
|
**kwargs,
|
490
|
-
) -> torch.
|
490
|
+
) -> torch.Tensor:
|
491
491
|
"""
|
492
492
|
One step for the second-order multistep DEIS.
|
493
493
|
|
494
494
|
Args:
|
495
|
-
model_output_list (`List[torch.
|
495
|
+
model_output_list (`List[torch.Tensor]`):
|
496
496
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
497
|
-
sample (`torch.
|
497
|
+
sample (`torch.Tensor`):
|
498
498
|
A current instance of a sample created by the diffusion process.
|
499
499
|
|
500
500
|
Returns:
|
501
|
-
`torch.
|
501
|
+
`torch.Tensor`:
|
502
502
|
The sample tensor at the previous timestep.
|
503
503
|
"""
|
504
504
|
timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
|
@@ -552,22 +552,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
552
552
|
|
553
553
|
def multistep_deis_third_order_update(
|
554
554
|
self,
|
555
|
-
model_output_list: List[torch.
|
555
|
+
model_output_list: List[torch.Tensor],
|
556
556
|
*args,
|
557
|
-
sample: torch.
|
557
|
+
sample: torch.Tensor = None,
|
558
558
|
**kwargs,
|
559
|
-
) -> torch.
|
559
|
+
) -> torch.Tensor:
|
560
560
|
"""
|
561
561
|
One step for the third-order multistep DEIS.
|
562
562
|
|
563
563
|
Args:
|
564
|
-
model_output_list (`List[torch.
|
564
|
+
model_output_list (`List[torch.Tensor]`):
|
565
565
|
The direct outputs from learned diffusion model at current and latter timesteps.
|
566
|
-
sample (`torch.
|
566
|
+
sample (`torch.Tensor`):
|
567
567
|
A current instance of a sample created by diffusion process.
|
568
568
|
|
569
569
|
Returns:
|
570
|
-
`torch.
|
570
|
+
`torch.Tensor`:
|
571
571
|
The sample tensor at the previous timestep.
|
572
572
|
"""
|
573
573
|
|
@@ -673,9 +673,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
673
673
|
|
674
674
|
def step(
|
675
675
|
self,
|
676
|
-
model_output: torch.
|
676
|
+
model_output: torch.Tensor,
|
677
677
|
timestep: int,
|
678
|
-
sample: torch.
|
678
|
+
sample: torch.Tensor,
|
679
679
|
return_dict: bool = True,
|
680
680
|
) -> Union[SchedulerOutput, Tuple]:
|
681
681
|
"""
|
@@ -683,11 +683,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
683
683
|
the multistep DEIS.
|
684
684
|
|
685
685
|
Args:
|
686
|
-
model_output (`torch.
|
686
|
+
model_output (`torch.Tensor`):
|
687
687
|
The direct output from learned diffusion model.
|
688
688
|
timestep (`float`):
|
689
689
|
The current discrete timestep in the diffusion chain.
|
690
|
-
sample (`torch.
|
690
|
+
sample (`torch.Tensor`):
|
691
691
|
A current instance of a sample created by the diffusion process.
|
692
692
|
return_dict (`bool`):
|
693
693
|
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
|
@@ -736,17 +736,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
736
736
|
|
737
737
|
return SchedulerOutput(prev_sample=prev_sample)
|
738
738
|
|
739
|
-
def scale_model_input(self, sample: torch.
|
739
|
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
740
740
|
"""
|
741
741
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
742
742
|
current timestep.
|
743
743
|
|
744
744
|
Args:
|
745
|
-
sample (`torch.
|
745
|
+
sample (`torch.Tensor`):
|
746
746
|
The input sample.
|
747
747
|
|
748
748
|
Returns:
|
749
|
-
`torch.
|
749
|
+
`torch.Tensor`:
|
750
750
|
A scaled input sample.
|
751
751
|
"""
|
752
752
|
return sample
|
@@ -754,10 +754,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
754
754
|
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
755
755
|
def add_noise(
|
756
756
|
self,
|
757
|
-
original_samples: torch.
|
758
|
-
noise: torch.
|
757
|
+
original_samples: torch.Tensor,
|
758
|
+
noise: torch.Tensor,
|
759
759
|
timesteps: torch.IntTensor,
|
760
|
-
) -> torch.
|
760
|
+
) -> torch.Tensor:
|
761
761
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
762
762
|
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
763
763
|
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
@@ -775,7 +775,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
775
775
|
# add_noise is called after first denoising step (for inpainting)
|
776
776
|
step_indices = [self.step_index] * timesteps.shape[0]
|
777
777
|
else:
|
778
|
-
# add noise is called
|
778
|
+
# add noise is called before first denoising step to create initial latent(img2img)
|
779
779
|
step_indices = [self.begin_index] * timesteps.shape[0]
|
780
780
|
|
781
781
|
sigma = sigmas[step_indices].flatten()
|