diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/image_processor.py
CHANGED
@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 PipelineImageInput = Union[
     PIL.Image.Image,
     np.ndarray,
-    torch.FloatTensor,
+    torch.Tensor,
     List[PIL.Image.Image],
     List[np.ndarray],
-    List[torch.FloatTensor],
+    List[torch.Tensor],
 ]

 PipelineDepthInput = PipelineImageInput


+def is_valid_image(image):
+    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
+
+
+def is_valid_image_imagelist(images):
+    # check if the image input is one of the supported formats for image and image list:
+    # it can be either one of below 3
+    # (1) a 4d pytorch tensor or numpy array,
+    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
+    # (3) a list of valid image
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
@@ -80,7 +99,6 @@ class VaeImageProcessor(ConfigMixin):
                 " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
                 " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
             )
-            self.config.do_convert_rgb = False

     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
@@ -111,7 +129,7 @@ class VaeImageProcessor(ConfigMixin):
         return images

     @staticmethod
-    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
+    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
         """
         Convert a NumPy image to a PyTorch tensor.
         """
@@ -122,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
         return images

     @staticmethod
-    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
+    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
         """
         Convert a PyTorch tensor to a NumPy image.
         """
@@ -173,8 +191,9 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
         """
-        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
-        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
+        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
+        processing are 512x512, the region will be expanded to 128x128.

         Args:
             mask_image (PIL.Image.Image): Mask image.
@@ -183,7 +202,8 @@ class VaeImageProcessor(ConfigMixin):
             pad (int, optional): Padding to be added to the crop region. Defaults to 0.

         Returns:
-            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
+            matches the original aspect ratio.
         """

         mask_image = mask_image.convert("L")
@@ -265,7 +285,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, filling empty with data from image.

         Args:
             image: The image to resize.
@@ -309,7 +330,8 @@ class VaeImageProcessor(ConfigMixin):
         height: int,
     ) -> PIL.Image.Image:
         """
-        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+        the image within the dimensions, cropping the excess.

         Args:
             image: The image to resize.
@@ -346,12 +368,12 @@ class VaeImageProcessor(ConfigMixin):
                 The width to resize to.
             resize_mode (`str`, *optional*, defaults to `default`):
                 The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
+                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
+                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.

         Returns:
             `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -456,19 +478,21 @@ class VaeImageProcessor(ConfigMixin):

         Args:
             image (`pipeline_image_input`):
-                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
+                supported formats.
             height (`int`, *optional*, defaults to `None`):
-                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
+                height.
             width (`int`, *optional*`, defaults to `None`):
-                The width in preprocessed. If `None`, will use
+                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
             resize_mode (`str`, *optional*, defaults to `default`):
-                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
-                within the specified width and height, and it may not maintaining the original aspect ratio.
-                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, filling empty with data from image.
-                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
-                within the dimensions, cropping the excess.
-                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
+                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
+                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+                supported for PIL image input.
             crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                 The crop coordinates for each image in the batch. If `None`, will not crop the image.
         """
@@ -492,12 +516,27 @@ class VaeImageProcessor(ConfigMixin):
             else:
                 image = np.expand_dims(image, axis=-1)

-        if isinstance(image, supported_formats):
-            image = [image]
-        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+                FutureWarning,
+            )
+            image = np.concatenate(image, axis=0)
+        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+                FutureWarning,
+            )
+            image = torch.cat(image, axis=0)
+
+        if not is_valid_image_imagelist(image):
             raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
+                f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
             )
+        if not isinstance(image, list):
+            image = [image]

         if isinstance(image[0], PIL.Image.Image):
             if crops_coords is not None:
@@ -556,15 +595,15 @@ class VaeImageProcessor(ConfigMixin):

     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.

         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -573,7 +612,7 @@ class VaeImageProcessor(ConfigMixin):
                 `VaeImageProcessor` config.

         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -733,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):

     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.

         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -750,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
                 `VaeImageProcessor` config.

         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -788,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):

     def preprocess(
         self,
-        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
         height: Optional[int] = None,
         width: Optional[int] = None,
         target_res: Optional[int] = None,
@@ -928,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
         )

     @staticmethod
-    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
         """
-        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
-        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

         Args:
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
             batch_size (`int`):
                 The batch size.
@@ -944,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
                 The dimensionality of the value embeddings.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The downsampled mask tensor.

         """
@@ -988,3 +1027,77 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
             )

         return mask_downsample
+
+
+class PixArtImageProcessor(VaeImageProcessor):
+    """
+    Image processor for PixArt image resize and crop.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
+        do_convert_rgb (`bool`, *optional*, defaults to be `False`):
+            Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
+            Whether to convert the images to grayscale format.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+        """Returns binned height and width."""
+        ar = float(height / width)
+        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+        default_hw = ratios[closest_ratio]
+        return int(default_hw[0]), int(default_hw[1])
+
+    @staticmethod
+    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+        orig_height, orig_width = samples.shape[2], samples.shape[3]
+
+        # Check if resizing is needed
+        if orig_height != new_height or orig_width != new_width:
+            ratio = max(new_height / orig_height, new_width / orig_width)
+            resized_width = int(orig_width * ratio)
+            resized_height = int(orig_height * ratio)
+
+            # Resize
+            samples = F.interpolate(
+                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+            )
+
+            # Center Crop
+            start_x = (resized_width - new_width) // 2
+            end_x = start_x + new_width
+            start_y = (resized_height - new_height) // 2
+            end_y = start_y + new_height
+            samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+        return samples
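The `is_valid_image` / `is_valid_image_imagelist` helpers and the 4d-list deprecation above change how `VaeImageProcessor.preprocess` validates its input. Below is a minimal sketch of the new behavior, assuming diffusers 0.28.0; the tensor shapes and random data are illustrative only:

```python
import numpy as np
import torch

from diffusers.image_processor import VaeImageProcessor, is_valid_image_imagelist

processor = VaeImageProcessor(vae_scale_factor=8)

# A single 4d tensor (B x C x H x W) and a list of single images are both valid inputs.
print(is_valid_image_imagelist(torch.rand(2, 3, 64, 64)))                       # True
print(is_valid_image_imagelist([np.random.rand(64, 64, 3) for _ in range(2)]))  # True

# A list of 4d tensors still works, but 0.28.0 emits a FutureWarning and
# concatenates the list along the batch dimension internally.
batch = processor.preprocess([torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)])
print(batch.shape)  # torch.Size([2, 3, 64, 64]), values normalized to [-1, 1]
```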
diffusers/loaders/__init__.py
CHANGED
@@ -54,9 +54,7 @@ if is_transformers_available():
 _import_structure = {}

 if is_torch_available():
-    _import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
-
-    _import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
+    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
@@ -70,8 +68,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from .autoencoder import FromOriginalVAEMixin
-        from .controlnet import FromOriginalControlNetMixin
+        from .single_file_model import FromOriginalModelMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
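`FromOriginalVAEMixin` and `FromOriginalControlNetMixin` are folded into the new `FromOriginalModelMixin` exposed from `single_file_model` (see `diffusers/loaders/single_file_model.py` in the file list above). The public entry point is unchanged; a minimal sketch, with checkpoint URLs that are illustrative examples rather than requirements:

```python
from diffusers import AutoencoderKL, ControlNetModel

# Both model classes now share the consolidated FromOriginalModelMixin,
# but callers keep using the same from_single_file() classmethod as before.
controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
```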
diffusers/loaders/autoencoder.py
CHANGED
@@ -50,9 +50,9 @@ class FromOriginalVAEMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -99,7 +99,7 @@ class FromOriginalVAEMixin:

         original_config_file = kwargs.pop("original_config_file", None)
         config_file = kwargs.pop("config_file", None)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
diffusers/loaders/controlnet.py
CHANGED
@@ -50,9 +50,9 @@ class FromOriginalControlNetMixin:
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -89,7 +89,7 @@ class FromOriginalControlNetMixin:
         """
         original_config_file = kwargs.pop("original_config_file", None)
         config_file = kwargs.pop("config_file", None)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         token = kwargs.pop("token", None)
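Both loaders above (and `ip_adapter.py` below) apply the same `resume_download` change: the kwarg now defaults to `None`, is ignored, and is slated for removal in Diffusers v1, since downloads resume by default. A short sketch of what that means for callers; the checkpoint URL is only an example:

```python
from diffusers import AutoencoderKL

url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"

# Interrupted downloads are resumed automatically; no flag is needed.
vae = AutoencoderKL.from_single_file(url)

# Still accepted for backward compatibility, but deprecated and ignored.
vae = AutoencoderKL.from_single_file(url, resume_download=True)
```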
diffusers/loaders/ip_adapter.py
CHANGED
@@ -16,17 +16,20 @@ from pathlib import Path
 from typing import Dict, List, Optional, Union

 import torch
+import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open

-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
+    USE_PEFT_BACKEND,
     _get_model_file,
     is_accelerate_available,
     is_torch_version,
     is_transformers_available,
     logging,
 )
+from .unet_loader_utils import _maybe_expand_lora_scales


 if is_transformers_available():
@@ -36,6 +39,8 @@ if is_transformers_available():
     )

     from ..models.attention_processor import (
+        AttnProcessor,
+        AttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
     )
@@ -67,26 +72,27 @@ class IPAdapterMixin:
                     - A [torch state
                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
             subfolder (`str` or `List[str]`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
-                If a list is passed, it should have the same length as `weight_name`.
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
             weight_name (`str` or `List[str]`):
                 The name of the weight file to load. If a list is passed, it should have the same length as
                 `weight_name`.
             image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                 The subfolder location of the image encoder within a larger model repository on the Hub or locally.
-                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
-                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
-                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
-                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+                `image_encoder_folder="different_subfolder/image_encoder"`.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -129,7 +135,7 @@ class IPAdapterMixin:
         # Load the main state dict first.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -182,7 +188,7 @@ class IPAdapterMixin:
                         elif key.startswith("ip_adapter."):
                             state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
             else:
-                state_dict = torch.load(model_file, map_location="cpu")
+                state_dict = load_state_dict(model_file)
         else:
             state_dict = pretrained_model_name_or_path_or_dict

@@ -227,27 +233,69 @@ class IPAdapterMixin:
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)

+        extra_loras = unet._load_ip_adapter_loras(state_dicts)
+        if extra_loras != {}:
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
     def set_ip_adapter_scale(self, scale):
         """
-        Sets the conditioning scale between text and image.
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

         Example:

         ```py
-        pipeline.set_ip_adapter_scale(0.5)
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style block only
+        scale = {
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style+layout blocks
+        scale = {
+            "down": {"block_2": [0.0, 1.0]},
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style and layout from 2 reference images
+        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+        pipeline.set_ip_adapter_scale(scales)
         ```
         """
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for attn_processor in unet.attn_processors.values():
+        if not isinstance(scale, list):
+            scale = [scale]
+        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+        for attn_name, attn_processor in unet.attn_processors.items():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
-                if not isinstance(scale, list):
-                    scale = [scale] * len(attn_processor.scale)
-                if len(attn_processor.scale) != len(scale):
+                if len(scale_configs) != len(attn_processor.scale):
                     raise ValueError(
-                        f"
-                        f"
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
                     )
-                attn_processor.scale = scale
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    if isinstance(scale_config, dict):
+                        for k, s in scale_config.items():
+                            if attn_name.startswith(k):
+                                attn_processor.scale[i] = s
+                    else:
+                        attn_processor.scale[i] = scale_config

     def unload_ip_adapter(self):
         """
@@ -278,4 +326,14 @@ class IPAdapterMixin:
         self.config.encoder_hid_dim_type = None

         # restore original Unet attention processors layers
-        self.unet.set_default_attn_processor()
+        attn_procs = {}
+        for name, value in self.unet.attn_processors.items():
+            attn_processor_class = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+            )
+            attn_procs[name] = (
+                attn_processor_class
+                if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+                else value.__class__()
+            )
+        self.unet.set_attn_processor(attn_procs)
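Taken together, these changes add per-transformer-block IP-Adapter scaling and a gentler `unload_ip_adapter`. Below is a minimal sketch following the docstring example added above; the base model and adapter repository (`stabilityai/stable-diffusion-xl-base-1.0`, `h94/IP-Adapter`) are the ones used in the diffusers documentation and stand in for any compatible checkpoint:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# A plain float keeps the original, global behavior.
pipeline.set_ip_adapter_scale(1.0)

# A dict sets scales per transformer block ("style only" configuration from the docstring).
pipeline.set_ip_adapter_scale({"up": {"block_0": [0.0, 1.0, 0.0]}})

# unload_ip_adapter() now restores the previous attention processor classes
# instead of resetting every processor to the default one.
pipeline.unload_ip_adapter()
```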