diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
CHANGED
@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear

         self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
             nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
         )

         self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
         )

         self.gradient_checkpointing = False
diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
CHANGED
@@ -209,7 +209,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
+        image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
         prompt: Union[str, List[str]] = None,
         num_inference_steps: int = 12,
         timesteps: Optional[List[float]] = None,
@@ -217,7 +217,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -228,7 +228,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         Function invoked when calling the pipeline for generation.

         Args:
-            image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+            image_embedding (`torch.Tensor` or `List[torch.Tensor]`):
                 Image Embeddings either extracted from an image or generated by a Prior Model.
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
@@ -252,7 +252,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
CHANGED
@@ -112,25 +112,25 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)

-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)

-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
         GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
         Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)

     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
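Both offload hooks now take an explicit `device` in addition to the legacy `gpu_id`, and the combined pipeline forwards both to its prior and decoder sub-pipelines. Below is a minimal sketch of the new call shape; the checkpoint name is illustrative and any Wuerstchen combined checkpoint should work the same way:

```python
import torch
from diffusers import WuerstchenCombinedPipeline

# Illustrative checkpoint; substitute your own Wuerstchen combined weights.
pipe = WuerstchenCombinedPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
)

# New in 0.28.0: device is an explicit argument; the same device is passed
# through to both prior_pipe and decoder_pipe.
pipe.enable_model_cpu_offload(device="cuda")

image = pipe(prompt="an astronaut riding a horse").images[0]
```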
@@ -154,11 +154,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
         decoder_timesteps: Optional[List[float]] = None,
         decoder_guidance_scale: float = 0.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -176,10 +176,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
                 prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
@@ -218,7 +218,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
CHANGED
@@ -54,12 +54,12 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
    Output class for WuerstchenPriorPipeline.

    Args:
-        image_embeddings (`torch.FloatTensor` or `np.ndarray`)
+        image_embeddings (`torch.Tensor` or `np.ndarray`)
            Prior image embeddings for text prompt

    """

-    image_embeddings: Union[torch.FloatTensor, np.ndarray]
+    image_embeddings: Union[torch.Tensor, np.ndarray]


 class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
@@ -136,8 +136,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         do_classifier_free_guidance,
         prompt=None,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
     ):
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -288,11 +288,11 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pt",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -324,10 +324,10 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -336,7 +336,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/schedulers/__init__.py
CHANGED
@@ -68,7 +68,7 @@ else:
     _import_structure["scheduling_tcd"] = ["TCDScheduler"]
     _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
-    _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
+    _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]

 try:
@@ -163,7 +163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .scheduling_tcd import TCDScheduler
     from .scheduling_unclip import UnCLIPScheduler
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
-    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+    from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler

 try:
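`AysSchedules` is newly exported from `diffusers.schedulers` and holds the Align-Your-Steps timestep tables. A hedged sketch of how the export can be consumed; the dictionary key `"StableDiffusionXLTimesteps"` and the pipeline's `timesteps` argument follow the 0.28.0 release notes and are assumptions here, not shown in this diff:

```python
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from diffusers.schedulers import AysSchedules

# Assumed key name: a 10-step Align-Your-Steps schedule for SDXL.
ays_timesteps = AysSchedules["StableDiffusionXLTimesteps"]

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# A custom timestep list replaces num_inference_steps.
image = pipe("a photo of a cat", timesteps=ays_timesteps).images[0]
```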
diffusers/schedulers/deprecated/__init__.py
CHANGED
@@ -30,7 +30,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
-        from ..utils.dummy_pt_objects import *  # noqa F403
+        from ...utils.dummy_pt_objects import *  # noqa F403
    else:
        from .scheduling_karras_ve import KarrasVeScheduler
        from .scheduling_sde_vp import ScoreSdeVpScheduler
diffusers/schedulers/deprecated/scheduling_karras_ve.py
CHANGED
@@ -31,19 +31,19 @@ class KarrasVeOutput(BaseOutput):
    Output class for the scheduler's step function output.

    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

-    prev_sample: torch.FloatTensor
-    derivative: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    derivative: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None


 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@@ -94,21 +94,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps: int = None
         self.timesteps: np.IntTensor = None
-        self.schedule: torch.FloatTensor = None  # sigma(t_i)
+        self.schedule: torch.Tensor = None  # sigma(t_i)

-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -136,14 +136,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)

     def add_noise_to_input(
-        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
-    ) -> Tuple[torch.FloatTensor, float]:
+        self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
+    ) -> Tuple[torch.Tensor, float]:
         """
         Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
         higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             sigma (`float`):
             generator (`torch.Generator`, *optional*):
@@ -163,10 +163,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         sigma_hat: float,
         sigma_prev: float,
-        sample_hat: torch.FloatTensor,
+        sample_hat: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
         """
@@ -174,11 +174,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             sigma_hat (`float`):
             sigma_prev (`float`):
-            sample_hat (`torch.FloatTensor`):
+            sample_hat (`torch.Tensor`):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`.

@@ -202,25 +202,25 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

     def step_correct(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         sigma_hat: float,
         sigma_prev: float,
-        sample_hat: torch.FloatTensor,
-        sample_prev: torch.FloatTensor,
-        derivative: torch.FloatTensor,
+        sample_hat: torch.Tensor,
+        sample_prev: torch.Tensor,
+        derivative: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
         """
         Corrects the predicted sample based on the `model_output` of the network.

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor`): TODO
-            sample_prev (`torch.FloatTensor`): TODO
-            derivative (`torch.FloatTensor`): TODO
+            sample_hat (`torch.Tensor`): TODO
+            sample_prev (`torch.Tensor`): TODO
+            derivative (`torch.Tensor`): TODO
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.

diffusers/schedulers/scheduling_amused.py
CHANGED
@@ -29,16 +29,16 @@ class AmusedSchedulerOutput(BaseOutput):
    Output class for the scheduler's `step` function output.

    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

-    prev_sample: torch.FloatTensor
-    pred_original_sample: torch.FloatTensor = None
+    prev_sample: torch.Tensor
+    pred_original_sample: torch.Tensor = None


 class AmusedScheduler(SchedulerMixin, ConfigMixin):
@@ -70,7 +70,7 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: torch.long,
         sample: torch.LongTensor,
         starting_mask_ratio: int = 1,
diffusers/schedulers/scheduling_consistency_decoder.py
CHANGED
@@ -45,7 +45,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)

     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
@@ -61,12 +61,12 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
    Output class for the scheduler's `step` function.

    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor


 class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
@@ -113,28 +113,28 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
     def init_noise_sigma(self):
         return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]

-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample * self.c_in[timestep]

     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
@@ -143,11 +143,11 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`float`):
                 The current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
diffusers/schedulers/scheduling_consistency_models.py
CHANGED
@@ -33,12 +33,12 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput):
    Output class for the scheduler's `step` function.

    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor


 class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
@@ -104,7 +104,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index

@@ -126,20 +126,18 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index

-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         # Get sigma corresponding to timestep
@@ -233,7 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         sigmas = self._convert_to_karras(ramp)
         timesteps = self.sigma_to_t(sigmas)

-        sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32)
+        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)

         if str(device).startswith("mps"):
@@ -278,7 +276,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         </Tip>

         Args:
-            sigma (`torch.FloatTensor`):
+            sigma (`torch.Tensor`):
                 The current sigma in the Karras sigma schedule.

         Returns:
@@ -319,9 +317,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
@@ -330,11 +328,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`float`):
                 The current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -349,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
             otherwise a tuple is returned where the first element is the sample tensor.
         """

-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -417,10 +411,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -434,7 +428,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
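The new `elif` branch distinguishes two callers of `add_noise`: inpainting pipelines re-noise after denoising has already begun, so the live `step_index` applies, while img2img noises the initial latent before the first step, so `begin_index` applies. A standalone restatement of that dispatch, with hypothetical names, purely to make the three cases explicit:

```python
# Sketch only (not library code): which sigma index add_noise should use.
def pick_step_indices(scheduler, timesteps):
    if scheduler.begin_index is None:
        # Training, or a pipeline that never called set_begin_index:
        # resolve each timestep against the schedule.
        return [scheduler.index_for_timestep(t) for t in timesteps]
    elif scheduler.step_index is not None:
        # Inpainting: add_noise runs again after the first denoising step.
        return [scheduler.step_index] * len(timesteps)
    else:
        # Img2img: add_noise builds the initial latent before the first step.
        return [scheduler.begin_index] * len(timesteps)
```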