diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_euler_ancestral_discrete.py

@@ -35,16 +35,16 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar

@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):

@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas

@@ -190,7 +190,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)

@@ -228,7 +228,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 

@@ -250,21 +250,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
 

@@ -346,9 +344,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:

@@ -357,11 +355,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.

@@ -377,11 +375,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         """
 
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"

@@ -450,10 +444,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

@@ -467,7 +461,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
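The `add_noise` change above (mirrored in `scheduling_euler_discrete.py` below) distinguishes two call patterns: img2img pipelines add noise once before the denoising loop starts, while inpainting pipelines re-noise after `step()` has already run. A minimal sketch of the pre-loop case, assuming 0.28.0 defaults; the shapes and step values here are illustrative, not from this diff:

```python
import torch
from diffusers import EulerAncestralDiscreteScheduler

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(30)

latents = torch.randn(1, 4, 64, 64)  # placeholder latents
noise = torch.randn_like(latents)

# img2img: the pipeline sets a begin index before the loop, so add_noise
# resolves sigma from `begin_index` (step_index is still None here).
scheduler.set_begin_index(10)
t = scheduler.timesteps[10:11]
noisy = scheduler.add_noise(latents, noise, t)

# Once the loop has called step() at least once, step_index is no longer
# None, and the new `elif` branch uses it instead (the inpainting case).
```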
diffusers/schedulers/scheduling_euler_discrete.py

@@ -35,16 +35,16 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar

@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):

@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas

@@ -167,6 +167,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
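The new `final_sigmas_type` option documented above decides which sigma the sampling schedule ends on. A quick sketch of both settings, assuming an otherwise default scheduler config (the printed values are behavioral expectations, not output from this diff):

```python
from diffusers import EulerDiscreteScheduler

# "zero" (the default): the appended final sigma is 0, so the last step
# lands exactly on the predicted clean sample.
scheduler = EulerDiscreteScheduler(final_sigmas_type="zero")
scheduler.set_timesteps(10)
print(scheduler.sigmas[-1])  # tensor(0.)

# "sigma_min": the schedule instead ends at the smallest training sigma.
scheduler = EulerDiscreteScheduler(final_sigmas_type="sigma_min")
scheduler.set_timesteps(10)
print(scheduler.sigmas[-1])  # small but non-zero
```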
@@ -189,6 +192,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep_type: str = "discrete",  # can be "discrete" or "continuous"
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -201,7 +205,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         if rescale_betas_zero_snr:
             self.betas = rescale_zero_terminal_snr(self.betas)

@@ -248,7 +252,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 

@@ -270,21 +274,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:

@@ -296,7 +298,13 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        timesteps: Optional[List[int]] = None,
+        sigmas: Optional[List[float]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 

@@ -305,60 +313,111 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
+                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
+                must be `None`, and `timestep_spacing` attribute will be ignored.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
+                will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
+                custom sigmas schedule.
         """
-        self.num_inference_steps = num_inference_steps
 
-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
-                ::-1
-            ].copy()
-        elif self.config.timestep_spacing == "leading":
-            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-            timesteps += self.config.steps_offset
-        elif self.config.timestep_spacing == "trailing":
-            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
-            timesteps -= 1
-        else:
+        if timesteps is not None and sigmas is not None:
+            raise ValueError("Only one of `timesteps` or `sigmas` should be set.")
+        if num_inference_steps is None and timesteps is None and sigmas is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.")
+        if num_inference_steps is not None and (timesteps is not None or sigmas is not None):
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
+        if (
+            timesteps is not None
+            and self.config.timestep_type == "continuous"
+            and self.config.prediction_type == "v_prediction"
+        ):
             raise ValueError(
-                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
             )
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
-        log_sigmas = np.log(sigmas)
+        if num_inference_steps is None:
+            num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
+        self.num_inference_steps = num_inference_steps
+
+        if sigmas is not None:
+            log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
+            sigmas = np.array(sigmas).astype(np.float32)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
 
-        if self.config.interpolation_type == "linear":
-            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        elif self.config.interpolation_type == "log_linear":
-            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
         else:
-            raise ValueError(
-                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
-                " 'linear' or 'log_linear'"
-            )
-
-        if self.config.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+            if timesteps is not None:
+                timesteps = np.array(timesteps).astype(np.float32)
+            else:
+                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+                if self.config.timestep_spacing == "linspace":
+                    timesteps = np.linspace(
+                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                    )[::-1].copy()
+                elif self.config.timestep_spacing == "leading":
+                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+                    # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_step is power of 3
+                    timesteps = (
+                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                    )
+                    timesteps += self.config.steps_offset
+                elif self.config.timestep_spacing == "trailing":
+                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+                    # creates integer timesteps by multiplying by ratio
+                    # casting to int to avoid issues when num_inference_step is power of 3
+                    timesteps = (
+                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                    )
+                    timesteps -= 1
+                else:
+                    raise ValueError(
+                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                    )
+
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            if self.config.interpolation_type == "linear":
+                sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+            elif self.config.interpolation_type == "log_linear":
+                sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
+            else:
+                raise ValueError(
+                    f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                    " 'linear' or 'log_linear'"
+                )
+
+            if self.config.use_karras_sigmas:
+                sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+                timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+            if self.config.final_sigmas_type == "sigma_min":
+                sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+            elif self.config.final_sigmas_type == "zero":
+                sigma_last = 0
+            else:
+                raise ValueError(
+                    f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
+                )
+
+            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
 
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
-            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
         self._step_index = None
         self._begin_index = None
-        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
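With the reworked `set_timesteps` above, exactly one of `num_inference_steps`, `timesteps`, or `sigmas` may be passed, so custom schedules no longer require subclassing. A minimal sketch; the sigma values are illustrative, not a recommended schedule:

```python
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler()

# A custom sigma schedule: the trailing 0.0 is the terminal sigma, and the
# matching timesteps are interpolated from the training schedule.
scheduler.set_timesteps(sigmas=[14.6, 6.3, 3.8, 2.2, 1.3, 0.9, 0.55, 0.38, 0.23, 0.11, 0.0])
print(scheduler.num_inference_steps)  # 10, i.e. len(sigmas) - 1
print(scheduler.timesteps.shape)      # torch.Size([10])
```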
@@ -384,7 +443,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break

@@ -433,9 +492,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         s_churn: float = 0.0,
         s_tmin: float = 0.0,
         s_tmax: float = float("inf"),

@@ -448,11 +507,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             s_churn (`float`):
             s_tmin (`float`):

@@ -471,11 +530,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
 
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"

@@ -545,10 +600,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

@@ -562,7 +617,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()

@@ -572,5 +631,42 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+        if (
+            isinstance(timesteps, int)
+            or isinstance(timesteps, torch.IntTensor)
+            or isinstance(timesteps, torch.LongTensor)
+        ):
+            raise ValueError(
+                (
+                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                    " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass"
+                    " one of the `scheduler.timesteps` as a timestep."
+                ),
+            )
+
+        if sample.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
+            timesteps = timesteps.to(sample.device, dtype=torch.float32)
+        else:
+            schedule_timesteps = self.timesteps.to(sample.device)
+            timesteps = timesteps.to(sample.device)
+
+        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        alphas_cumprod = self.alphas_cumprod.to(sample)
+        sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
     def __len__(self):
         return self.config.num_train_timesteps
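The new `get_velocity` helper mirrors the one on `DDPMScheduler` and computes the v-prediction target, `v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample`, which lets training scripts use `EulerDiscreteScheduler` with `prediction_type="v_prediction"` directly. A rough sketch of computing a training target; the model and loss lines are placeholders, not part of this diff:

```python
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(prediction_type="v_prediction")
scheduler.set_timesteps(1000)

clean_latents = torch.randn(2, 4, 64, 64)  # placeholder latents
noise = torch.randn_like(clean_latents)

# Draw timesteps from scheduler.timesteps; plain integer indices are rejected.
idx = torch.randint(0, len(scheduler.timesteps), (2,))
timesteps = scheduler.timesteps[idx]

target = scheduler.get_velocity(clean_latents, noise, timesteps)
# loss = F.mse_loss(model(noisy_latents, timesteps).sample, target)
```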