diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff compares the published contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_heun_discrete.py

@@ -57,7 +57,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -135,7 +135,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         elif beta_schedule == "exp":
             self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -174,7 +174,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -198,21 +198,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -224,9 +224,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def set_timesteps(
         self,
-        num_inference_steps: int,
+        num_inference_steps: Optional[int] = None,
         device: Union[str, torch.device] = None,
         num_train_timesteps: Optional[int] = None,
+        timesteps: Optional[List[int]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -236,30 +237,47 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            num_train_timesteps (`int`, *optional*):
+                The number of diffusion steps used when training the model. If `None`, the default
+                `num_train_timesteps` attribute is used.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
+                generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
+                must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+
+        num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
-
         num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
 
-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
-        elif self.config.timestep_spacing == "leading":
-            step_ratio = num_train_timesteps // self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-            timesteps += self.config.steps_offset
-        elif self.config.timestep_spacing == "trailing":
-            step_ratio = num_train_timesteps / self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
-            timesteps -= 1
+        if timesteps is not None:
+            timesteps = np.array(timesteps, dtype=np.float32)
         else:
-            raise ValueError(
-                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-            )
+            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+            if self.config.timestep_spacing == "linspace":
+                timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
+            elif self.config.timestep_spacing == "leading":
+                step_ratio = num_train_timesteps // self.num_inference_steps
+                # creates integer timesteps by multiplying by ratio
+                # casting to int to avoid issues when num_inference_step is power of 3
+                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                timesteps += self.config.steps_offset
+            elif self.config.timestep_spacing == "trailing":
+                step_ratio = num_train_timesteps / self.num_inference_steps
+                # creates integer timesteps by multiplying by ratio
+                # casting to int to avoid issues when num_inference_step is power of 3
+                timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                timesteps -= 1
+            else:
+                raise ValueError(
+                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                )
 
         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         log_sigmas = np.log(sigmas)
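
The hunk above is the substantive change in this file: set_timesteps now accepts explicit timesteps as an alternative to num_inference_steps. A minimal usage sketch, assuming diffusers 0.28.0 (the step values are illustrative, not from the source):

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()

# Existing path: spacing derived from config.timestep_spacing.
scheduler.set_timesteps(num_inference_steps=5)

# New path: pass explicit timesteps instead. Exactly one of the two
# arguments may be given, and custom timesteps are rejected when
# config.use_karras_sigmas is enabled.
scheduler.set_timesteps(timesteps=[999, 749, 499, 249, 0])
print(scheduler.timesteps)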
@@ -311,7 +329,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -351,9 +369,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -361,11 +379,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -451,10 +469,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -468,7 +486,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
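
The add_noise change above distinguishes two call sites: img2img-style pipelines noise the latents once before the first denoising step (step_index is still None, so begin_index is used), while inpainting-style pipelines re-noise after stepping has begun (step_index is used). A short sketch of the first case, assuming diffusers 0.28.0:

import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=5)
scheduler.set_begin_index(0)  # pipelines that support it set this before the loop

sample = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(sample)

# No scheduler.step() has run yet, so step_index is None and the
# begin_index branch selects the starting sigma (the img2img case).
noisy = scheduler.add_noise(sample, noise, scheduler.timesteps[:1])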
diffusers/schedulers/scheduling_ipndm.py

@@ -61,7 +61,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         the linear multistep method. It performs one forward pass multiple times to approximate the solution.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -58,7 +58,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -129,7 +129,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -151,7 +151,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -175,21 +175,21 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -321,7 +321,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -376,9 +376,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -387,11 +387,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -477,10 +477,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -494,7 +494,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -57,7 +57,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -128,7 +128,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -151,7 +151,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -175,21 +175,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -334,7 +334,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
         # Hack to make sure that other schedulers which copy this function don't break
@@ -361,9 +361,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -371,11 +371,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -452,10 +452,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -469,7 +469,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
diffusers/schedulers/scheduling_karras_ve_flax.py

@@ -176,10 +176,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         Args:
             state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_hat (`torch.Tensor` or `np.ndarray`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
 
         Returns:
@@ -213,12 +213,12 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         Args:
             state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
-            derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_hat (`torch.Tensor` or `np.ndarray`): TODO
+            sample_prev (`torch.Tensor` or `np.ndarray`): TODO
+            derivative (`torch.Tensor` or `np.ndarray`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class
 
         Returns:
diffusers/schedulers/scheduling_lcm.py

@@ -37,16 +37,16 @@ class LCMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    denoised: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    denoised: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -84,7 +84,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -95,17 +95,17 @@ def betas_for_alpha_bar(
 
 
 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -224,7 +224,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -296,24 +296,24 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -497,9 +497,9 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[LCMSchedulerOutput, Tuple]:
@@ -508,11 +508,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -594,10 +594,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -619,9 +619,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
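
Nearly all remaining hunks in these scheduler files are one mechanical change: public annotations widen from torch.FloatTensor to torch.Tensor. Behavior is unchanged; the broader type simply matches what the schedulers already accept, such as half-precision latents. A small sketch, assuming diffusers 0.28.0:

import torch
from diffusers import LCMScheduler

scheduler = LCMScheduler()
scheduler.set_timesteps(num_inference_steps=4)

# A float16 input is a plain torch.Tensor but not a torch.FloatTensor,
# so the old annotations were narrower than actual usage.
sample = torch.randn(1, 4, 64, 64, dtype=torch.float16)
scaled = scheduler.scale_model_input(sample, timestep=scheduler.timesteps[0])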