diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
--- a/diffusers/schedulers/scheduling_ddim.py
+++ b/diffusers/schedulers/scheduling_ddim.py
@@ -35,16 +35,16 @@ class DDIMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
 
 
    Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
    Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
@@ -211,7 +211,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -233,19 +233,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -261,7 +261,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -341,13 +341,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
@@ -355,11 +355,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             eta (`float`):
                 The weight of noise for added noise in diffusion step.
@@ -370,7 +370,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
                 `use_clipped_model_output` has no effect.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -470,10 +470,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -495,9 +495,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
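
Every hunk above is the same mechanical change: the public `DDIMScheduler` annotations move from `torch.FloatTensor` to the plain `torch.Tensor` base class. For orientation, a minimal sketch of the denoising loop these signatures describe; the zero-noise `fake_model` is a stand-in invented here (not part of diffusers) so the snippet runs without a real UNet:

```python
import torch
from diffusers import DDIMScheduler

# Stand-in for a UNet noise predictor, so the loop below is self-contained.
def fake_model(sample: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.zeros_like(sample)

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

sample = torch.randn(1, 4, 64, 64)  # (batch, channels, height, width) latent
for t in scheduler.timesteps:
    model_output = fake_model(sample, t)
    # step() now accepts and returns plain torch.Tensor values
    sample = scheduler.step(model_output, t, sample).prev_sample
```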
--- a/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/diffusers/schedulers/scheduling_ddim_flax.py
@@ -85,7 +85,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         trained_betas (`jnp.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
         clip_sample (`bool`, default `True`):
-            option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
+            option to clip predicted sample between for numerical stability. The clip range is determined by
+            `clip_sample_range`.
         clip_sample_range (`float`, default `1.0`):
             the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
         set_alpha_to_one (`bool`, default `True`):
--- a/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -33,16 +33,16 @@ class DDIMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -80,7 +80,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
    else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
    betas = []
    for i in range(num_diffusion_timesteps):
@@ -97,11 +97,11 @@ def rescale_zero_terminal_snr(betas):
 
 
    Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
    Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
@@ -207,7 +207,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -231,19 +231,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -288,9 +288,9 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
@@ -298,11 +298,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             eta (`float`):
                 The weight of noise for added noise in diffusion step.
@@ -311,7 +311,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
             clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
             `use_clipped_model_output` has no effect.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
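
`DDIMInverseScheduler` receives the same retyping. Note in the hunk at line 231 that its `timesteps` are built without the `[::-1]` reversal used by `DDIMScheduler`, so they ascend from 0, which is what makes the class suitable for inverting a clean latent back toward noise. A hedged sketch of that forward loop, again with a stand-in model in place of a real UNet:

```python
import torch
from diffusers import DDIMInverseScheduler

inverse_scheduler = DDIMInverseScheduler(num_train_timesteps=1000)
inverse_scheduler.set_timesteps(50)  # timesteps ascend instead of descending

latent = torch.randn(1, 4, 64, 64)  # stand-in for an encoded image latent
for t in inverse_scheduler.timesteps:
    noise_pred = torch.zeros_like(latent)  # stand-in for a UNet prediction
    latent = inverse_scheduler.step(noise_pred, t, latent).prev_sample
```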
--- a/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -35,16 +35,16 @@ class DDIMParallelSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
    else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
    betas = []
    for i in range(num_diffusion_timesteps):
@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
 
 
    Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
    Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
@@ -218,7 +218,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -241,19 +241,19 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -283,7 +283,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -364,13 +364,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
         """
@@ -378,9 +378,9 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            model_output (`torch.Tensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -388,7 +388,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
                 `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                 coincide with the one provided as input and `use_clipped_model_output` will have not effect.
             generator: random number generator.
-            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+            variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
                 can directly provide the noise for the variance itself. This is useful for methods such as
                 CycleDiffusion. (https://arxiv.org/abs/2210.05559)
             return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class
@@ -486,12 +486,12 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
     def batch_step_no_noise(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timesteps: List[int],
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
         Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
@@ -501,10 +501,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            model_output (`torch.Tensor`): direct output from learned diffusion model.
             timesteps (`List[int]`):
                 current discrete timesteps in the diffusion chain. This is now a list of integers.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -513,7 +513,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
                 coincide with the one provided as input and `use_clipped_model_output` will have not effect.
 
         Returns:
-            `torch.FloatTensor`: sample tensor at previous timestep.
+            `torch.Tensor`: sample tensor at previous timestep.
 
         """
         if self.num_inference_steps is None:
@@ -595,10 +595,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -620,9 +620,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
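
Beyond the shared retyping, these hunks also touch `DDIMParallelScheduler.batch_step_no_noise`, which reverses the SDE for several samples at several timesteps in one call and deliberately adds no noise, since parallel sampling injects its noise separately. A rough, hedged sketch of the call shape, with zero tensors standing in for model predictions; the timesteps are passed as a 1-D tensor here, which is what the batched indexing expects:

```python
import torch
from diffusers import DDIMParallelScheduler

scheduler = DDIMParallelScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)  # required: batch_step_no_noise checks num_inference_steps

# Step 4 samples at 4 different timesteps in a single batched call.
timesteps = scheduler.timesteps[:4]
samples = torch.randn(4, 4, 64, 64)
model_outputs = torch.zeros_like(samples)  # stand-in for UNet predictions

prev_samples = scheduler.batch_step_no_noise(model_outputs, timesteps, samples, eta=0.0)
```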
--- a/diffusers/schedulers/scheduling_ddpm.py
+++ b/diffusers/schedulers/scheduling_ddpm.py
@@ -33,16 +33,16 @@ class DDPMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 def betas_for_alpha_bar(
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
    else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
    betas = []
    for i in range(num_diffusion_timesteps):
@@ -96,11 +96,11 @@ def rescale_zero_terminal_snr(betas):
 
 
    Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
    Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
@@ -211,7 +211,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             betas = torch.linspace(-6, 6, num_train_timesteps)
             self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -231,19 +231,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         self.variance_type = variance_type
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -363,7 +363,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -398,9 +398,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -409,11 +409,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -498,10 +498,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -522,9 +522,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
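
`DDPMScheduler.add_noise` and `get_velocity` (retyped above) implement the closed-form forward process and the v-prediction training target: `noisy = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise` and `v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0`. A small runnable sketch of how a training step would use them, with random tensors standing in for real latents:

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

x0 = torch.randn(2, 4, 64, 64)    # stand-in clean latents
noise = torch.randn_like(x0)
t = torch.randint(0, 1000, (2,))  # one timestep per sample

noisy = scheduler.add_noise(x0, noise, t)      # forward process q(x_t | x_0)
target = scheduler.get_velocity(x0, noise, t)  # v-prediction regression target
# With prediction_type="epsilon" the regression target would be `noise` itself.
```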
--- a/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -222,9 +222,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         t = timestep
 
         if key is None:
-            key = jax.random.PRNGKey(0)
+            key = jax.random.key(0)
 
-        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
+        if (
+            len(model_output.shape) > 1
+            and model_output.shape[1] == sample.shape[1] * 2
+            and self.config.variance_type in ["learned", "learned_range"]
+        ):
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
@@ -264,7 +268,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         # 6. Add noise
         def random_variance():
-            split_key = jax.random.split(key, num=1)
+            split_key = jax.random.split(key, num=1)[0]
             noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
             return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise
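
The Flax hunks are behavioral fixes rather than retyping: `jax.random.key(0)` replaces the legacy `jax.random.PRNGKey(0)`, and `jax.random.split(key, num=1)` returns an array holding one key (a leading axis of size 1), so the added `[0]` index is needed before the result can be fed to `jax.random.normal`. A small sketch of the distinction:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new typed-key API, replacing jax.random.PRNGKey(0)

keys = jax.random.split(key, num=1)  # an *array* of one key, shape (1,)
single = keys[0]                     # a scalar key, as the fixed code now passes

noise = jax.random.normal(single, shape=(1, 4, 64, 64), dtype=jnp.float32)
```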