diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -32,15 +32,15 @@ class SdeVeOutput(BaseOutput):
|
|
32
32
|
Output class for the scheduler's `step` function output.
|
33
33
|
|
34
34
|
Args:
|
35
|
-
prev_sample (`torch.
|
35
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
36
36
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
37
37
|
denoising loop.
|
38
|
-
prev_sample_mean (`torch.
|
38
|
+
prev_sample_mean (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
39
39
|
Mean averaged `prev_sample` over previous timesteps.
|
40
40
|
"""
|
41
41
|
|
42
|
-
prev_sample: torch.
|
43
|
-
prev_sample_mean: torch.
|
42
|
+
prev_sample: torch.Tensor
|
43
|
+
prev_sample_mean: torch.Tensor
|
44
44
|
|
45
45
|
|
46
46
|
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
@@ -86,19 +86,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
86
86
|
|
87
87
|
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
88
88
|
|
89
|
-
def scale_model_input(self, sample: torch.
|
89
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
90
90
|
"""
|
91
91
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
92
92
|
current timestep.
|
93
93
|
|
94
94
|
Args:
|
95
|
-
sample (`torch.
|
95
|
+
sample (`torch.Tensor`):
|
96
96
|
The input sample.
|
97
97
|
timestep (`int`, *optional*):
|
98
98
|
The current timestep in the diffusion chain.
|
99
99
|
|
100
100
|
Returns:
|
101
|
-
`torch.
|
101
|
+
`torch.Tensor`:
|
102
102
|
A scaled input sample.
|
103
103
|
"""
|
104
104
|
return sample
|
@@ -159,9 +159,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
159
159
|
|
160
160
|
def step_pred(
|
161
161
|
self,
|
162
|
-
model_output: torch.
|
162
|
+
model_output: torch.Tensor,
|
163
163
|
timestep: int,
|
164
|
-
sample: torch.
|
164
|
+
sample: torch.Tensor,
|
165
165
|
generator: Optional[torch.Generator] = None,
|
166
166
|
return_dict: bool = True,
|
167
167
|
) -> Union[SdeVeOutput, Tuple]:
|
@@ -170,11 +170,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
170
170
|
process from the learned model outputs (most often the predicted noise).
|
171
171
|
|
172
172
|
Args:
|
173
|
-
model_output (`torch.
|
173
|
+
model_output (`torch.Tensor`):
|
174
174
|
The direct output from learned diffusion model.
|
175
175
|
timestep (`int`):
|
176
176
|
The current discrete timestep in the diffusion chain.
|
177
|
-
sample (`torch.
|
177
|
+
sample (`torch.Tensor`):
|
178
178
|
A current instance of a sample created by the diffusion process.
|
179
179
|
generator (`torch.Generator`, *optional*):
|
180
180
|
A random number generator.
|
@@ -227,8 +227,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
227
227
|
|
228
228
|
def step_correct(
|
229
229
|
self,
|
230
|
-
model_output: torch.
|
231
|
-
sample: torch.
|
230
|
+
model_output: torch.Tensor,
|
231
|
+
sample: torch.Tensor,
|
232
232
|
generator: Optional[torch.Generator] = None,
|
233
233
|
return_dict: bool = True,
|
234
234
|
) -> Union[SchedulerOutput, Tuple]:
|
@@ -237,9 +237,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
237
237
|
making the prediction for the previous timestep.
|
238
238
|
|
239
239
|
Args:
|
240
|
-
model_output (`torch.
|
240
|
+
model_output (`torch.Tensor`):
|
241
241
|
The direct output from learned diffusion model.
|
242
|
-
sample (`torch.
|
242
|
+
sample (`torch.Tensor`):
|
243
243
|
A current instance of a sample created by the diffusion process.
|
244
244
|
generator (`torch.Generator`, *optional*):
|
245
245
|
A random number generator.
|
@@ -282,10 +282,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|
282
282
|
|
283
283
|
def add_noise(
|
284
284
|
self,
|
285
|
-
original_samples: torch.
|
286
|
-
noise: torch.
|
287
|
-
timesteps: torch.
|
288
|
-
) -> torch.
|
285
|
+
original_samples: torch.Tensor,
|
286
|
+
noise: torch.Tensor,
|
287
|
+
timesteps: torch.Tensor,
|
288
|
+
) -> torch.Tensor:
|
289
289
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
290
290
|
timesteps = timesteps.to(original_samples.device)
|
291
291
|
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
|
@@ -37,15 +37,15 @@ class TCDSchedulerOutput(BaseOutput):
|
|
37
37
|
Output class for the scheduler's `step` function output.
|
38
38
|
|
39
39
|
Args:
|
40
|
-
prev_sample (`torch.
|
40
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
41
41
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
42
42
|
denoising loop.
|
43
|
-
pred_noised_sample (`torch.
|
43
|
+
pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
44
44
|
The predicted noised sample `(x_{s})` based on the model output from the current timestep.
|
45
45
|
"""
|
46
46
|
|
47
|
-
prev_sample: torch.
|
48
|
-
pred_noised_sample: Optional[torch.
|
47
|
+
prev_sample: torch.Tensor
|
48
|
+
pred_noised_sample: Optional[torch.Tensor] = None
|
49
49
|
|
50
50
|
|
51
51
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
@@ -83,7 +83,7 @@ def betas_for_alpha_bar(
|
|
83
83
|
return math.exp(t * -12.0)
|
84
84
|
|
85
85
|
else:
|
86
|
-
raise ValueError(f"Unsupported
|
86
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
87
87
|
|
88
88
|
betas = []
|
89
89
|
for i in range(num_diffusion_timesteps):
|
@@ -94,17 +94,17 @@ def betas_for_alpha_bar(
|
|
94
94
|
|
95
95
|
|
96
96
|
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
97
|
-
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
|
97
|
+
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
98
98
|
"""
|
99
99
|
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
100
100
|
|
101
101
|
|
102
102
|
Args:
|
103
|
-
betas (`torch.
|
103
|
+
betas (`torch.Tensor`):
|
104
104
|
the betas that the scheduler is being initialized with.
|
105
105
|
|
106
106
|
Returns:
|
107
|
-
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
107
|
+
`torch.Tensor`: rescaled betas with zero terminal SNR
|
108
108
|
"""
|
109
109
|
# Convert betas to alphas_bar_sqrt
|
110
110
|
alphas = 1.0 - betas
|
@@ -132,8 +132,8 @@ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
|
|
132
132
|
|
133
133
|
class TCDScheduler(SchedulerMixin, ConfigMixin):
|
134
134
|
"""
|
135
|
-
`TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
|
136
|
-
extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
|
135
|
+
`TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency
|
136
|
+
Distillation`, extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
|
137
137
|
|
138
138
|
This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).
|
139
139
|
|
@@ -225,7 +225,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
225
225
|
# Glide cosine schedule
|
226
226
|
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
227
227
|
else:
|
228
|
-
raise NotImplementedError(f"{beta_schedule}
|
228
|
+
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
229
229
|
|
230
230
|
# Rescale for zero SNR
|
231
231
|
if rescale_betas_zero_snr:
|
@@ -297,18 +297,19 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
297
297
|
"""
|
298
298
|
self._begin_index = begin_index
|
299
299
|
|
300
|
-
def scale_model_input(self, sample: torch.
|
300
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
301
301
|
"""
|
302
302
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
303
303
|
current timestep.
|
304
304
|
|
305
305
|
Args:
|
306
|
-
sample (`torch.
|
306
|
+
sample (`torch.Tensor`):
|
307
307
|
The input sample.
|
308
308
|
timestep (`int`, *optional*):
|
309
309
|
The current timestep in the diffusion chain.
|
310
|
+
|
310
311
|
Returns:
|
311
|
-
`torch.
|
312
|
+
`torch.Tensor`:
|
312
313
|
A scaled input sample.
|
313
314
|
"""
|
314
315
|
return sample
|
@@ -325,7 +326,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
325
326
|
return variance
|
326
327
|
|
327
328
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
328
|
-
def _threshold_sample(self, sample: torch.
|
329
|
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
329
330
|
"""
|
330
331
|
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
331
332
|
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
364
365
|
device: Union[str, torch.device] = None,
|
365
366
|
original_inference_steps: Optional[int] = None,
|
366
367
|
timesteps: Optional[List[int]] = None,
|
367
|
-
strength:
|
368
|
+
strength: float = 1.0,
|
368
369
|
):
|
369
370
|
"""
|
370
371
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
384
385
|
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
385
386
|
timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
|
386
387
|
schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
|
388
|
+
strength (`float`, *optional*, defaults to 1.0):
|
389
|
+
Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
|
387
390
|
"""
|
388
391
|
# 0. Check inputs
|
389
392
|
if num_inference_steps is None and timesteps is None:
|
@@ -521,9 +524,9 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
521
524
|
|
522
525
|
def step(
|
523
526
|
self,
|
524
|
-
model_output: torch.
|
527
|
+
model_output: torch.Tensor,
|
525
528
|
timestep: int,
|
526
|
-
sample: torch.
|
529
|
+
sample: torch.Tensor,
|
527
530
|
eta: float = 0.3,
|
528
531
|
generator: Optional[torch.Generator] = None,
|
529
532
|
return_dict: bool = True,
|
@@ -533,15 +536,16 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
533
536
|
process from the learned model outputs (most often the predicted noise).
|
534
537
|
|
535
538
|
Args:
|
536
|
-
model_output (`torch.
|
539
|
+
model_output (`torch.Tensor`):
|
537
540
|
The direct output from learned diffusion model.
|
538
541
|
timestep (`int`):
|
539
542
|
The current discrete timestep in the diffusion chain.
|
540
|
-
sample (`torch.
|
543
|
+
sample (`torch.Tensor`):
|
541
544
|
A current instance of a sample created by the diffusion process.
|
542
545
|
eta (`float`):
|
543
|
-
A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
|
544
|
-
When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
|
546
|
+
A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
|
547
|
+
step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic
|
548
|
+
sampling.
|
545
549
|
generator (`torch.Generator`, *optional*):
|
546
550
|
A random number generator.
|
547
551
|
return_dict (`bool`, *optional*, defaults to `True`):
|
@@ -624,14 +628,18 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
624
628
|
|
625
629
|
return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
|
626
630
|
|
631
|
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
627
632
|
def add_noise(
|
628
633
|
self,
|
629
|
-
original_samples: torch.
|
630
|
-
noise: torch.
|
634
|
+
original_samples: torch.Tensor,
|
635
|
+
noise: torch.Tensor,
|
631
636
|
timesteps: torch.IntTensor,
|
632
|
-
) -> torch.FloatTensor:
|
637
|
+
) -> torch.Tensor:
|
633
638
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
634
|
-
|
639
|
+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
640
|
+
# for the subsequent add_noise calls
|
641
|
+
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
|
642
|
+
alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
|
635
643
|
timesteps = timesteps.to(original_samples.device)
|
636
644
|
|
637
645
|
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
@@ -647,11 +655,11 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
647
655
|
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
648
656
|
return noisy_samples
|
649
657
|
|
650
|
-
|
651
|
-
|
652
|
-
) -> torch.FloatTensor:
|
658
|
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
659
|
+
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
|
653
660
|
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
654
|
-
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
661
|
+
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
|
662
|
+
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
|
655
663
|
timesteps = timesteps.to(sample.device)
|
656
664
|
|
657
665
|
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
@@ -670,6 +678,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
|
670
678
|
def __len__(self):
|
671
679
|
return self.config.num_train_timesteps
|
672
680
|
|
681
|
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
673
682
|
def previous_timestep(self, timestep):
|
674
683
|
if self.custom_timesteps:
|
675
684
|
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
@@ -32,16 +32,16 @@ class UnCLIPSchedulerOutput(BaseOutput):
|
|
32
32
|
Output class for the scheduler's `step` function output.
|
33
33
|
|
34
34
|
Args:
|
35
|
-
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
35
|
+
prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
36
36
|
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
37
37
|
denoising loop.
|
38
|
-
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
38
|
+
pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
|
39
39
|
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
40
40
|
`pred_original_sample` can be used to preview progress or for guidance.
|
41
41
|
"""
|
42
42
|
|
43
|
-
prev_sample: torch.FloatTensor
|
44
|
-
pred_original_sample: Optional[torch.FloatTensor] = None
|
43
|
+
prev_sample: torch.Tensor
|
44
|
+
pred_original_sample: Optional[torch.Tensor] = None
|
45
45
|
|
46
46
|
|
47
47
|
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
|
|
79
79
|
return math.exp(t * -12.0)
|
80
80
|
|
81
81
|
else:
|
82
|
-
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
82
|
+
raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
|
83
83
|
|
84
84
|
betas = []
|
85
85
|
for i in range(num_diffusion_timesteps):
|
@@ -146,17 +146,17 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
|
146
146
|
|
147
147
|
self.variance_type = variance_type
|
148
148
|
|
149
|
-
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
149
|
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
150
150
|
"""
|
151
151
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
152
152
|
current timestep.
|
153
153
|
|
154
154
|
Args:
|
155
|
-
sample (`torch.FloatTensor`): input sample
|
155
|
+
sample (`torch.Tensor`): input sample
|
156
156
|
timestep (`int`, optional): current timestep
|
157
157
|
|
158
158
|
Returns:
|
159
|
-
`torch.FloatTensor`: scaled input sample
|
159
|
+
`torch.Tensor`: scaled input sample
|
160
160
|
"""
|
161
161
|
return sample
|
162
162
|
|
@@ -215,9 +215,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
|
215
215
|
|
216
216
|
def step(
|
217
217
|
self,
|
218
|
-
model_output: torch.FloatTensor,
|
218
|
+
model_output: torch.Tensor,
|
219
219
|
timestep: int,
|
220
|
-
sample: torch.FloatTensor,
|
220
|
+
sample: torch.Tensor,
|
221
221
|
prev_timestep: Optional[int] = None,
|
222
222
|
generator=None,
|
223
223
|
return_dict: bool = True,
|
@@ -227,9 +227,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
|
227
227
|
process from the learned model outputs (most often the predicted noise).
|
228
228
|
|
229
229
|
Args:
|
230
|
-
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
230
|
+
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
231
231
|
timestep (`int`): current discrete timestep in the diffusion chain.
|
232
|
-
sample (`torch.FloatTensor`):
|
232
|
+
sample (`torch.Tensor`):
|
233
233
|
current instance of sample being created by diffusion process.
|
234
234
|
prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
|
235
235
|
Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
|
@@ -327,10 +327,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
|
|
327
327
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
328
328
|
def add_noise(
|
329
329
|
self,
|
330
|
-
original_samples: torch.FloatTensor,
|
331
|
-
noise: torch.FloatTensor,
|
330
|
+
original_samples: torch.Tensor,
|
331
|
+
noise: torch.Tensor,
|
332
332
|
timesteps: torch.IntTensor,
|
333
|
-
) -> torch.FloatTensor:
|
333
|
+
) -> torch.Tensor:
|
334
334
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
335
335
|
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
336
336
|
# for the subsequent add_noise calls
|