diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
--- a/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/diffusers/schedulers/scheduling_lms_discrete.py
@@ -32,16 +32,16 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
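The dominant change across the scheduler hunks in this diff is the migration of type hints and docstrings from `torch.FloatTensor` to the plain `torch.Tensor` type. A minimal sketch of how the updated output dataclass is consumed in a denoising loop (the zeroed `noise_pred` is a stand-in for a real UNet call, not part of this diff):

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(25)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    noise_pred = torch.zeros_like(model_input)      # stand-in for a UNet prediction
    output = scheduler.step(noise_pred, t, sample)  # LMSDiscreteSchedulerOutput
    sample = output.prev_sample                     # plain torch.Tensor as of 0.28.0
```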
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
         return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -149,7 +149,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -180,7 +180,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -202,21 +202,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
 
@@ -288,7 +286,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
-        if self.use_karras_sigmas:
+        if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
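The `self.use_karras_sigmas` → `self.config.use_karras_sigmas` fix is worth a note: scheduler constructor arguments registered through `register_to_config` are exposed on the `config` namespace, and bare attribute access appears to have worked only through a deprecation shim. A hedged sketch of the access pattern:

```python
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(use_karras_sigmas=True)
# Arguments captured by @register_to_config live on scheduler.config:
assert scheduler.config.use_karras_sigmas is True
```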
@@ -351,7 +349,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         sigma_min: float = in_sigmas[-1].item()
@@ -366,9 +364,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         order: int = 4,
         return_dict: bool = True,
     ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
@@ -377,11 +375,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             order (`int`, defaults to 4):
                 The order of the linear multistep method.
@@ -444,10 +442,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -461,7 +459,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
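The new `elif` branch encodes when `add_noise` runs relative to the denoising loop, as the inline comments state: img2img pipelines noise the initial latent before any step (so `begin_index` applies), while inpainting pipelines re-noise the known region after steps have already run (so the live `step_index` applies). A hedged sketch of the img2img-style call:

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(25)
scheduler.set_begin_index(0)  # pipelines that support it set this up front

image_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(image_latents)
# Called before any scheduler.step(): step_index is still None, so the
# begin_index branch picks the sigma for the first inference timestep.
noisy_latents = scheduler.add_noise(image_latents, noise, scheduler.timesteps[:1])
```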
--- a/diffusers/schedulers/scheduling_pndm.py
+++ b/diffusers/schedulers/scheduling_pndm.py
@@ -59,7 +59,7 @@ def betas_for_alpha_bar(
         return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -135,7 +135,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -225,9 +225,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -236,11 +236,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         or [`~PNDMScheduler.step_plms`] depending on the internal variable `counter`.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -258,9 +258,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step_prk(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -269,11 +269,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         equation.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -318,9 +318,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step_plms(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -328,11 +328,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         the linear multistep method. It performs one forward pass multiple times to approximate the solution.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -387,17 +387,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -448,10 +448,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
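The PNDM changes are purely annotation updates; the dispatch the `step` docstring mentions (Runge-Kutta warm-up via `step_prk`, then linear multistep via `step_plms`, selected by the internal `counter`) is unchanged. A hedged sketch of that loop (zeros stand in for a model prediction):

```python
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler()
scheduler.set_timesteps(10)

sample = torch.randn(1, 4, 32, 32)
for t in scheduler.timesteps:
    eps = torch.zeros_like(sample)                       # stand-in for a UNet prediction
    sample = scheduler.step(eps, t, sample).prev_sample  # dispatches on scheduler.counter
```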
--- a/diffusers/schedulers/scheduling_repaint.py
+++ b/diffusers/schedulers/scheduling_repaint.py
@@ -31,16 +31,16 @@ class RePaintSchedulerOutput(BaseOutput):
     Output class for the scheduler's step function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from
            the current timestep. `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
+    pred_original_sample: torch.Tensor
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -78,7 +78,7 @@ def betas_for_alpha_bar(
         return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -143,7 +143,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
             betas = torch.linspace(-6, 6, num_train_timesteps)
             self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -160,19 +160,19 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
         self.eta = eta
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -245,11 +245,11 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
-        original_image: torch.FloatTensor,
-        mask: torch.FloatTensor,
+        sample: torch.Tensor,
+        original_image: torch.Tensor,
+        mask: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[RePaintSchedulerOutput, Tuple]:
@@ -258,15 +258,15 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            original_image (`torch.FloatTensor`):
+            original_image (`torch.Tensor`):
                 The original image to inpaint on.
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The mask where a value of 0.0 indicates which part of the original image to inpaint.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -351,10 +351,10 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         raise NotImplementedError("Use `DDPMScheduler.add_noise()` to train for sampling with RePaint.")
 
     def __len__(self):
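Because `RePaintScheduler.step` threads the known image and its mask through every update, its call shape differs from the other schedulers in this diff. A hedged single-step sketch (the full RePaint loop additionally uses `undo_step` for the resampling jumps):

```python
import torch
from diffusers import RePaintScheduler

scheduler = RePaintScheduler()
scheduler.set_timesteps(num_inference_steps=50)

original_image = torch.randn(1, 3, 64, 64)  # the known image
mask = torch.ones_like(original_image)      # 0.0 marks the region to inpaint
sample = torch.randn_like(original_image)

t = scheduler.timesteps[0]
eps = torch.zeros_like(sample)              # stand-in for a UNet prediction
out = scheduler.step(eps, t, sample, original_image, mask)
sample = out.prev_sample                    # known region re-noised from original_image
```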
--- a/diffusers/schedulers/scheduling_sasolver.py
+++ b/diffusers/schedulers/scheduling_sasolver.py
@@ -62,7 +62,7 @@ def betas_for_alpha_bar(
         return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -92,19 +92,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         predictor_order (`int`, defaults to 2):
-            The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided
-            sampling, and `predictor_order=3` for unconditional sampling.
+            The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for
+            guided sampling, and `predictor_order=3` for unconditional sampling.
         corrector_order (`int`, defaults to 2):
-            The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided
-            sampling, and `corrector_order=3` for unconditional sampling.
+            The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for
+            guided sampling, and `corrector_order=3` for unconditional sampling.
         prediction_type (`str`, defaults to `epsilon`, *optional*):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://imagen.research.google/video/paper.pdf) paper).
         tau_func (`Callable`, *optional*):
-            Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`. SA-Solver
-            will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample from vanilla
-            diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check https://arxiv.org/abs/2309.05019
+            Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
+            SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
+            from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
+            https://arxiv.org/abs/2309.05019
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
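The reworded `tau_func` entry above describes a per-timestep stochasticity knob; a hedged sketch of the three regimes it names:

```python
from diffusers import SASolverScheduler

# tau(t) = 0 everywhere: deterministic sampling from the vanilla diffusion ODE
ode_like = SASolverScheduler(tau_func=lambda t: 0)

# tau(t) = 1 everywhere: stochastic sampling from the vanilla diffusion SDE
sde_like = SASolverScheduler(tau_func=lambda t: 1)

# default (tau_func=None): stochastic only for timesteps t in [200, 800]
default = SASolverScheduler()
```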
@@ -114,8 +115,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
         algorithm_type (`str`, defaults to `data_prediction`):
-            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction`
-            with `solver_order=2` for guided sampling like in Stable Diffusion.
+            Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
+            `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Default = True.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -179,7 +180,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -193,7 +194,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0
 
         if algorithm_type not in ["data_prediction", "noise_prediction"]:
-            raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
 
         # setable values
         self.num_inference_steps = None
@@ -216,7 +217,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -304,7 +305,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -369,7 +370,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -396,31 +397,31 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def convert_model_output(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor = None,
+        sample: torch.Tensor = None,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
-        Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. Noise_prediction is
-        designed to discretize an integral of the noise prediction model, and data_prediction is designed to discretize an
-        integral of the data prediction model.
+        Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
+        Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
+        designed to discretize an integral of the data prediction model.
 
         <Tip>
 
-        The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both noise
-        prediction and data prediction models.
+        The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
+        noise prediction and data prediction models.
 
         </Tip>
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The converted model output.
         """
         timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
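The `<Tip>` above says the solver algorithm and the model's parameterization are independent choices; a hedged sketch of mixing them:

```python
from diffusers import SASolverScheduler

# An epsilon-predicting model sampled with the data-prediction solver:
scheduler = SASolverScheduler(
    algorithm_type="data_prediction",  # solver discretizes the data-prediction integral
    prediction_type="epsilon",         # the model itself still predicts noise
)
```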
@@ -685,29 +686,29 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def stochastic_adams_bashforth_update(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         *args,
-        sample: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        sample: torch.Tensor,
+        noise: torch.Tensor,
         order: int,
-        tau: torch.FloatTensor,
+        tau: torch.Tensor,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the SA-Predictor.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model at the current timestep.
             prev_timestep (`int`):
                 The previous discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             order (`int`):
                 The order of SA-Predictor at this timestep.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
         prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
@@ -812,32 +813,32 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def stochastic_adams_moulton_update(
         self,
-        this_model_output: torch.FloatTensor,
+        this_model_output: torch.Tensor,
         *args,
-        last_sample: torch.FloatTensor,
-        last_noise: torch.FloatTensor,
-        this_sample: torch.FloatTensor,
+        last_sample: torch.Tensor,
+        last_noise: torch.Tensor,
+        this_sample: torch.Tensor,
         order: int,
-        tau: torch.FloatTensor,
+        tau: torch.Tensor,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         One step for the SA-Corrector.
 
         Args:
-            this_model_output (`torch.FloatTensor`):
+            this_model_output (`torch.Tensor`):
                 The model outputs at `x_t`.
             this_timestep (`int`):
                 The current timestep `t`.
-            last_sample (`torch.FloatTensor`):
+            last_sample (`torch.Tensor`):
                 The generated sample before the last predictor `x_{t-1}`.
-            this_sample (`torch.FloatTensor`):
+            this_sample (`torch.Tensor`):
                 The generated sample after the last predictor `x_{t}`.
             order (`int`):
                 The order of SA-Corrector at this step.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The corrected sample tensor at the current timestep.
         """
 
@@ -978,9 +979,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -989,11 +990,11 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
         the SA-Solver.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -1078,17 +1079,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -1096,10 +1097,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
|