diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +7 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +274 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
- diffusers-0.27.0.dist-info/RECORD +399 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
- diffusers-0.26.3.dist-info/RECORD +0 -384
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -46,7 +46,7 @@ from ....utils import (
|
|
46
46
|
unscale_lora_layers,
|
47
47
|
)
|
48
48
|
from ....utils.torch_utils import randn_tensor
|
49
|
-
from ...pipeline_utils import DiffusionPipeline
|
49
|
+
from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
50
50
|
from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
51
51
|
from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
52
52
|
|
@@ -280,7 +280,7 @@ class Pix2PixZeroAttnProcessor:
|
|
280
280
|
return hidden_states
|
281
281
|
|
282
282
|
|
283
|
-
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
283
|
+
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
|
284
284
|
r"""
|
285
285
|
Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.
|
286
286
|
|
@@ -463,7 +463,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
|
463
463
|
batch_size = prompt_embeds.shape[0]
|
464
464
|
|
465
465
|
if prompt_embeds is None:
|
466
|
-
# textual inversion:
|
466
|
+
# textual inversion: process multi-vector tokens if necessary
|
467
467
|
if isinstance(self, TextualInversionLoaderMixin):
|
468
468
|
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
469
469
|
|
@@ -545,7 +545,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
|
545
545
|
else:
|
546
546
|
uncond_tokens = negative_prompt
|
547
547
|
|
548
|
-
# textual inversion:
|
548
|
+
# textual inversion: process multi-vector tokens if necessary
|
549
549
|
if isinstance(self, TextualInversionLoaderMixin):
|
550
550
|
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
551
551
|
|
@@ -268,7 +268,6 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
|
268
268
|
return objs
|
269
269
|
|
270
270
|
|
271
|
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
|
272
271
|
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
273
272
|
r"""
|
274
273
|
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
@@ -1334,7 +1333,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|
1334
1333
|
**additional_residuals,
|
1335
1334
|
)
|
1336
1335
|
else:
|
1337
|
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb
|
1336
|
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
1338
1337
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
1339
1338
|
sample += down_intrablock_additional_residuals.pop(0)
|
1340
1339
|
|
@@ -1590,7 +1589,7 @@ class DownBlockFlat(nn.Module):
|
|
1590
1589
|
self.gradient_checkpointing = False
|
1591
1590
|
|
1592
1591
|
def forward(
|
1593
|
-
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
|
1592
|
+
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
|
1594
1593
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1595
1594
|
output_states = ()
|
1596
1595
|
|
@@ -1612,13 +1611,13 @@ class DownBlockFlat(nn.Module):
|
|
1612
1611
|
create_custom_forward(resnet), hidden_states, temb
|
1613
1612
|
)
|
1614
1613
|
else:
|
1615
|
-
hidden_states = resnet(hidden_states, temb
|
1614
|
+
hidden_states = resnet(hidden_states, temb)
|
1616
1615
|
|
1617
1616
|
output_states = output_states + (hidden_states,)
|
1618
1617
|
|
1619
1618
|
if self.downsamplers is not None:
|
1620
1619
|
for downsampler in self.downsamplers:
|
1621
|
-
hidden_states = downsampler(hidden_states
|
1620
|
+
hidden_states = downsampler(hidden_states)
|
1622
1621
|
|
1623
1622
|
output_states = output_states + (hidden_states,)
|
1624
1623
|
|
@@ -1729,8 +1728,6 @@ class CrossAttnDownBlockFlat(nn.Module):
|
|
1729
1728
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
1730
1729
|
output_states = ()
|
1731
1730
|
|
1732
|
-
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
1733
|
-
|
1734
1731
|
blocks = list(zip(self.resnets, self.attentions))
|
1735
1732
|
|
1736
1733
|
for i, (resnet, attn) in enumerate(blocks):
|
@@ -1761,7 +1758,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
|
1761
1758
|
return_dict=False,
|
1762
1759
|
)[0]
|
1763
1760
|
else:
|
1764
|
-
hidden_states = resnet(hidden_states, temb
|
1761
|
+
hidden_states = resnet(hidden_states, temb)
|
1765
1762
|
hidden_states = attn(
|
1766
1763
|
hidden_states,
|
1767
1764
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -1779,7 +1776,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
|
1779
1776
|
|
1780
1777
|
if self.downsamplers is not None:
|
1781
1778
|
for downsampler in self.downsamplers:
|
1782
|
-
hidden_states = downsampler(hidden_states
|
1779
|
+
hidden_states = downsampler(hidden_states)
|
1783
1780
|
|
1784
1781
|
output_states = output_states + (hidden_states,)
|
1785
1782
|
|
@@ -1843,8 +1840,13 @@ class UpBlockFlat(nn.Module):
|
|
1843
1840
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
1844
1841
|
temb: Optional[torch.FloatTensor] = None,
|
1845
1842
|
upsample_size: Optional[int] = None,
|
1846
|
-
|
1843
|
+
*args,
|
1844
|
+
**kwargs,
|
1847
1845
|
) -> torch.FloatTensor:
|
1846
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
1847
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
1848
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
1849
|
+
|
1848
1850
|
is_freeu_enabled = (
|
1849
1851
|
getattr(self, "s1", None)
|
1850
1852
|
and getattr(self, "s2", None)
|
@@ -1888,11 +1890,11 @@ class UpBlockFlat(nn.Module):
|
|
1888
1890
|
create_custom_forward(resnet), hidden_states, temb
|
1889
1891
|
)
|
1890
1892
|
else:
|
1891
|
-
hidden_states = resnet(hidden_states, temb
|
1893
|
+
hidden_states = resnet(hidden_states, temb)
|
1892
1894
|
|
1893
1895
|
if self.upsamplers is not None:
|
1894
1896
|
for upsampler in self.upsamplers:
|
1895
|
-
hidden_states = upsampler(hidden_states, upsample_size
|
1897
|
+
hidden_states = upsampler(hidden_states, upsample_size)
|
1896
1898
|
|
1897
1899
|
return hidden_states
|
1898
1900
|
|
@@ -2000,7 +2002,10 @@ class CrossAttnUpBlockFlat(nn.Module):
|
|
2000
2002
|
attention_mask: Optional[torch.FloatTensor] = None,
|
2001
2003
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
2002
2004
|
) -> torch.FloatTensor:
|
2003
|
-
|
2005
|
+
if cross_attention_kwargs is not None:
|
2006
|
+
if cross_attention_kwargs.get("scale", None) is not None:
|
2007
|
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
2008
|
+
|
2004
2009
|
is_freeu_enabled = (
|
2005
2010
|
getattr(self, "s1", None)
|
2006
2011
|
and getattr(self, "s2", None)
|
@@ -2054,7 +2059,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
|
2054
2059
|
return_dict=False,
|
2055
2060
|
)[0]
|
2056
2061
|
else:
|
2057
|
-
hidden_states = resnet(hidden_states, temb
|
2062
|
+
hidden_states = resnet(hidden_states, temb)
|
2058
2063
|
hidden_states = attn(
|
2059
2064
|
hidden_states,
|
2060
2065
|
encoder_hidden_states=encoder_hidden_states,
|
@@ -2066,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
|
2066
2071
|
|
2067
2072
|
if self.upsamplers is not None:
|
2068
2073
|
for upsampler in self.upsamplers:
|
2069
|
-
hidden_states = upsampler(hidden_states, upsample_size
|
2074
|
+
hidden_states = upsampler(hidden_states, upsample_size)
|
2070
2075
|
|
2071
2076
|
return hidden_states
|
2072
2077
|
|
@@ -2159,7 +2164,7 @@ class UNetMidBlockFlat(nn.Module):
|
|
2159
2164
|
attentions = []
|
2160
2165
|
|
2161
2166
|
if attention_head_dim is None:
|
2162
|
-
logger.
|
2167
|
+
logger.warning(
|
2163
2168
|
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
2164
2169
|
)
|
2165
2170
|
attention_head_dim = in_channels
|
@@ -2331,8 +2336,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
|
2331
2336
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
2332
2337
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
2333
2338
|
) -> torch.FloatTensor:
|
2334
|
-
|
2335
|
-
|
2339
|
+
if cross_attention_kwargs is not None:
|
2340
|
+
if cross_attention_kwargs.get("scale", None) is not None:
|
2341
|
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
2342
|
+
|
2343
|
+
hidden_states = self.resnets[0](hidden_states, temb)
|
2336
2344
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
2337
2345
|
if self.training and self.gradient_checkpointing:
|
2338
2346
|
|
@@ -2369,7 +2377,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
|
2369
2377
|
encoder_attention_mask=encoder_attention_mask,
|
2370
2378
|
return_dict=False,
|
2371
2379
|
)[0]
|
2372
|
-
hidden_states = resnet(hidden_states, temb
|
2380
|
+
hidden_states = resnet(hidden_states, temb)
|
2373
2381
|
|
2374
2382
|
return hidden_states
|
2375
2383
|
|
@@ -2470,7 +2478,8 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
|
2470
2478
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
2471
2479
|
) -> torch.FloatTensor:
|
2472
2480
|
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
2473
|
-
|
2481
|
+
if cross_attention_kwargs.get("scale", None) is not None:
|
2482
|
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
|
2474
2483
|
|
2475
2484
|
if attention_mask is None:
|
2476
2485
|
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
|
@@ -2483,7 +2492,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
|
2483
2492
|
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
|
2484
2493
|
mask = attention_mask
|
2485
2494
|
|
2486
|
-
hidden_states = self.resnets[0](hidden_states, temb
|
2495
|
+
hidden_states = self.resnets[0](hidden_states, temb)
|
2487
2496
|
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
2488
2497
|
# attn
|
2489
2498
|
hidden_states = attn(
|
@@ -2494,6 +2503,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
|
2494
2503
|
)
|
2495
2504
|
|
2496
2505
|
# resnet
|
2497
|
-
hidden_states = resnet(hidden_states, temb
|
2506
|
+
hidden_states = resnet(hidden_states, temb)
|
2498
2507
|
|
2499
2508
|
return hidden_states
|
diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -246,7 +246,6 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
|
246
246
|
extra_step_kwargs["generator"] = generator
|
247
247
|
return extra_step_kwargs
|
248
248
|
|
249
|
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
250
249
|
def check_inputs(
|
251
250
|
self,
|
252
251
|
prompt,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -4,7 +4,7 @@
|
|
4
4
|
# Copyright (c) 2021 OpenAI
|
5
5
|
# MIT License
|
6
6
|
#
|
7
|
-
# Copyright
|
7
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
8
8
|
#
|
9
9
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
10
10
|
# you may not use this file except in compliance with the License.
|
@@ -0,0 +1,184 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import math
|
16
|
+
from typing import Tuple, Union
|
17
|
+
|
18
|
+
import torch
|
19
|
+
import torch.fft as fft
|
20
|
+
|
21
|
+
from ..utils.torch_utils import randn_tensor
|
22
|
+
|
23
|
+
|
24
|
+
class FreeInitMixin:
    r"""Mixin class for FreeInit (https://arxiv.org/abs/2312.07537).

    FreeInit improves temporal consistency of video diffusion pipelines by iteratively
    re-initializing the low-frequency components of the initial latent noise. Pipelines
    inheriting this mixin gain the `enable_free_init`/`disable_free_init` toggles plus
    the internal frequency-domain helpers used during sampling.
    """

    def enable_free_init(
        self,
        num_iters: int = 3,
        use_fast_sampling: bool = False,
        method: str = "butterworth",
        order: int = 4,
        spatial_stop_frequency: float = 0.25,
        temporal_stop_frequency: float = 0.25,
    ):
        """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.

        This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).

        Args:
            num_iters (`int`, *optional*, defaults to `3`):
                Number of FreeInit noise re-initialization iterations.
            use_fast_sampling (`bool`, *optional*, defaults to `False`):
                Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
                the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
            method (`str`, *optional*, defaults to `butterworth`):
                Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
                FreeInit low pass filter.
            order (`int`, *optional*, defaults to `4`):
                Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
                whereas lower values lead to `gaussian` method behaviour.
            spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
                the original implementation.
            temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
                Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
                the original implementation.
        """
        self._free_init_num_iters = num_iters
        self._free_init_use_fast_sampling = use_fast_sampling
        self._free_init_method = method
        self._free_init_order = order
        self._free_init_spatial_stop_frequency = spatial_stop_frequency
        self._free_init_temporal_stop_frequency = temporal_stop_frequency

    def disable_free_init(self):
        """Disables the FreeInit mechanism if enabled."""
        self._free_init_num_iters = None

    @property
    def free_init_enabled(self):
        # Enabled iff `enable_free_init` has been called and not subsequently disabled.
        return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None

    def _get_free_init_freq_filter(
        self,
        shape: Tuple[int, ...],
        device: Union[str, torch.device],
        filter_type: str,
        order: float,
        spatial_stop_frequency: float,
        temporal_stop_frequency: float,
    ) -> torch.Tensor:
        r"""Returns the FreeInit low-pass filter mask for `shape` based on `filter_type`.

        The last three axes of `shape` are interpreted as (frames, height, width);
        the returned mask has values in [0, 1] and lives on `device`.
        """

        time, height, width = shape[-3], shape[-2], shape[-1]
        mask = torch.zeros(shape)

        # A stop frequency of 0 keeps no low-frequency content: return the all-zero mask.
        if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
            return mask

        if filter_type == "butterworth":

            def retrieve_mask(x):
                return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
        elif filter_type == "gaussian":

            def retrieve_mask(x):
                return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
        elif filter_type == "ideal":

            def retrieve_mask(x):
                return 1 if x <= spatial_stop_frequency * 2 else 0
        else:
            raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")

        # NOTE: O(T*H*W) Python loop; acceptable because the filter is small and built once per call.
        for t in range(time):
            for h in range(height):
                for w in range(width):
                    # Squared normalized distance from the (centered) frequency-domain origin;
                    # the temporal axis is rescaled by d_s/d_t so one cutoff applies to all axes.
                    d_square = (
                        ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
                        + (2 * h / height - 1) ** 2
                        + (2 * w / width - 1) ** 2
                    )
                    mask[..., t, h, w] = retrieve_mask(d_square)

        return mask.to(device)

    def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
        r"""Noise reinitialization: combine the low frequencies of `x` with the high frequencies of `noise`."""
        # FFT over the (frame, height, width) axes, zero frequency shifted to the center.
        x_freq = fft.fftn(x, dim=(-3, -2, -1))
        x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
        noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
        noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))

        # frequency mix
        high_pass_filter = 1 - low_pass_filter
        x_freq_low = x_freq * low_pass_filter
        noise_freq_high = noise_freq * high_pass_filter
        x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain

        # IFFT back to the spatial domain; imaginary residue is numerical noise, keep the real part.
        x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
        x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real

        return x_mixed

    def _apply_free_init(
        self,
        latents: torch.Tensor,
        free_init_iteration: int,
        num_inference_steps: int,
        device: torch.device,
        dtype: torch.dtype,
        generator: torch.Generator,
    ):
        """Performs one FreeInit noise re-initialization iteration.

        Returns a `(latents, timesteps)` tuple; on iteration 0 the incoming latents are
        stashed as the reference noise and returned unchanged.
        """
        if free_init_iteration == 0:
            # First iteration: remember the initial noise so later iterations can re-diffuse to it.
            self._free_init_initial_noise = latents.detach().clone()
            return latents, self.scheduler.timesteps

        latent_shape = latents.shape

        # Build one filter with a singleton batch dim; it broadcasts over the batch.
        free_init_filter_shape = (1, *latent_shape[1:])
        free_init_freq_filter = self._get_free_init_freq_filter(
            shape=free_init_filter_shape,
            device=device,
            filter_type=self._free_init_method,
            order=self._free_init_order,
            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
        )

        # Diffuse the current latents back to the last training timestep (full noise level).
        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()

        z_t = self.scheduler.add_noise(
            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
        ).to(dtype=torch.float32)

        z_rand = randn_tensor(
            shape=latent_shape,
            generator=generator,
            device=device,
            dtype=torch.float32,
        )
        # Keep the low frequencies of the re-diffused latents; replace highs with fresh noise.
        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
        latents = latents.to(dtype)

        # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
        if self._free_init_use_fast_sampling:
            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
            self.scheduler.set_timesteps(num_inference_steps, device=device)

        return latents, self.scheduler.timesteps
|