diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -497,19 +497,19 @@ class TextualInversionLoaderMixin:
     # load embeddings of text_encoder 1 (CLIP ViT-L/14)
     pipeline.load_textual_inversion(
         state_dict["clip_l"],
-
+        tokens=["<s0>", "<s1>"],
         text_encoder=pipeline.text_encoder,
         tokenizer=pipeline.tokenizer,
     )
     # load embeddings of text_encoder 2 (CLIP ViT-G/14)
     pipeline.load_textual_inversion(
         state_dict["clip_g"],
-
+        tokens=["<s0>", "<s1>"],
         text_encoder=pipeline.text_encoder_2,
         tokenizer=pipeline.tokenizer_2,
     )

-    # Unload explicitly from both text encoders
+    # Unload explicitly from both text encoders and tokenizers
     pipeline.unload_textual_inversion(
         tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
    )
@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
             tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
             key_id += 1
         tokenizer._update_trie()
+        # set correct total vocab size after removing tokens
+        tokenizer._update_total_vocab_size()

         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
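The hunks above are from the `TextualInversionLoaderMixin` in diffusers/loaders/textual_inversion.py (+5 -3 in the file list): the SDXL docstring example now passes `tokens=` and unloads from both text encoders and tokenizers, and the unload path now calls `tokenizer._update_total_vocab_size()` so the tokenizer reports the right vocab size after tokens are removed. A minimal sketch mirroring the updated example; the checkpoint ID and embedding file are placeholders, and the file is assumed to contain "clip_l" and "clip_g" tensors:

from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file

pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
state_dict = load_file("embeddings.safetensors")  # hypothetical file with "clip_l" / "clip_g" entries

# load embeddings into text_encoder 1 (CLIP ViT-L/14) and text_encoder 2 (CLIP ViT-G/14)
pipeline.load_textual_inversion(
    state_dict["clip_l"], tokens=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], tokens=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2,
)

# unload explicitly from both text encoders and tokenizers
pipeline.unload_textual_inversion(
    tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
)
pipeline.unload_textual_inversion(
    tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
)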
diffusers/loaders/transformer_flux.py
ADDED
@@ -0,0 +1,181 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from contextlib import nullcontext
+
+from ..models.embeddings import (
+    ImageProjection,
+    MultiIPAdapterImageProjection,
+)
+from ..models.modeling_utils import load_model_dict_into_meta
+from ..utils import (
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+if is_accelerate_available():
+    pass
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTransformer2DLoadersMixin:
+    """
+    Load layers into a [`FluxTransformer2DModel`].
+    """
+
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        updated_state_dict = {}
+        image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            if state_dict["proj.weight"].shape[0] == 65536:
+                num_image_text_embeds = 16
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
+
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict, strict=True)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
+        return image_projection
+
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+        from ..models.attention_processor import (
+            FluxIPAdapterJointAttnProcessor2_0,
+        )
+
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        key_id = 0
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+        for name in self.attn_processors.keys():
+            if name.startswith("single_transformer_blocks"):
+                attn_processor_class = self.attn_processors[name].__class__
+                attn_procs[name] = attn_processor_class()
+            else:
+                cross_attention_dim = self.config.joint_attention_dim
+                hidden_size = self.inner_dim
+                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+                num_image_text_embeds = []
+                for state_dict in state_dicts:
+                    if "proj.weight" in state_dict["image_proj"]:
+                        num_image_text_embed = 4
+                        if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
+                            num_image_text_embed = 16
+                        # IP-Adapter
+                        num_image_text_embeds += [num_image_text_embed]
+
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                        dtype=self.dtype,
+                        device=self.device,
+                    )
+
+                value_dict = {}
+                for i, state_dict in enumerate(state_dicts):
+                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = self.device
+                    dtype = self.dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+                key_id += 1
+
+        return attn_procs
+
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+        if not isinstance(state_dicts, list):
+            state_dicts = [state_dicts]
+
+        self.encoder_hid_proj = None
+
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        self.set_attn_processor(attn_procs)
+
+        image_projection_layers = []
+        for state_dict in state_dicts:
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
+            image_projection_layers.append(image_projection_layer)
+
+        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
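The new FluxTransformer2DLoadersMixin sizes the image-projection layer from the checkpoint itself: a `proj.weight` with 65536 output features is treated as 16 image tokens (otherwise 4), and the cross-attention dimension is the output size divided by the token count. A standalone sketch of that shape inference using dummy tensors (not a real IP-Adapter checkpoint):

import torch

# dummy projection weights shaped like an IP-Adapter "image_proj" state dict
state_dict = {"proj.weight": torch.zeros(65536, 768), "proj.bias": torch.zeros(65536)}

num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
    num_image_text_embeds = 16  # the larger projection maps to 16 image tokens

clip_embeddings_dim = state_dict["proj.weight"].shape[-1]  # 768 (image embedding width)
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds  # 65536 // 16 = 4096

print(num_image_text_embeds, clip_embeddings_dim, cross_attention_dim)  # 16 768 4096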
diffusers/loaders/transformer_sd3.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+
+from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+from ..models.embeddings import IPAdapterTimeImageProjection
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+
+
+class SD3Transformer2DLoadersMixin:
+    """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
+
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+        Args:
+            state_dict (`Dict`):
+                State dict with keys "ip_adapter", which contains parameters for attention processors, and
+                "image_proj", which contains parameters for image projection net.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
+        # IP-Adapter cross attention parameters
+        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+
+        # Dict where key is transformer layer index, value is attention processor's state dict
+        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+        for key, weights in state_dict["ip_adapter"].items():
+            idx, name = key.split(".", maxsplit=1)
+            layer_state_dict[int(idx)][name] = weights
+
+        # Create IP-Adapter attention processor
+        attn_procs = {}
+        for idx, name in enumerate(self.attn_processors.keys()):
+            attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+                hidden_size=hidden_size,
+                ip_hidden_states_dim=ip_hidden_states_dim,
+                head_dim=self.config.attention_head_dim,
+                timesteps_emb_dim=timesteps_emb_dim,
+            ).to(self.device, dtype=self.dtype)
+
+            if not low_cpu_mem_usage:
+                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+            else:
+                load_model_dict_into_meta(
+                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+                )
+
+        self.set_attn_processor(attn_procs)
+
+        # Image projetion parameters
+        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
+        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
+        hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
+        heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
+        num_queries = state_dict["image_proj"]["latents"].shape[1]
+        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+
+        # Image projection
+        self.image_proj = IPAdapterTimeImageProjection(
+            embed_dim=embed_dim,
+            output_dim=output_dim,
+            hidden_dim=hidden_dim,
+            heads=heads,
+            num_queries=num_queries,
+            timestep_in_dim=timestep_in_dim,
+        ).to(device=self.device, dtype=self.dtype)
+
+        if not low_cpu_mem_usage:
+            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+        else:
+            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
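SD3Transformer2DLoadersMixin._load_ip_adapter_weights regroups the flat "ip_adapter" state dict into one sub-dict per attention processor by splitting each key on its leading layer index (e.g. "0.norm_ip.linear.weight" becomes layer 0, parameter "norm_ip.linear.weight"). A small sketch of just that grouping step with dummy keys and a stand-in layer count:

import torch

# dummy flat IP-Adapter state dict; "<layer>.<param path>" keys as in the loader above
ip_adapter_sd = {
    "0.norm_ip.linear.weight": torch.zeros(2, 2),
    "0.to_k_ip.weight": torch.zeros(2, 2),
    "1.norm_ip.linear.weight": torch.zeros(2, 2),
}
num_layers = 2  # stands in for len(self.attn_processors)

layer_state_dict = {idx: {} for idx in range(num_layers)}
for key, weights in ip_adapter_sd.items():
    idx, name = key.split(".", maxsplit=1)
    layer_state_dict[int(idx)][name] = weights

print(sorted(layer_state_dict[0]))  # ['norm_ip.linear.weight', 'to_k_ip.weight']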
diffusers/loaders/unet.py
CHANGED
@@ -36,6 +36,7 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -115,6 +116,9 @@ class UNet2DConditionLoadersMixin:
                 `default_{i}` where i is the total number of adapters being loaded.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.

         Example:

@@ -142,8 +146,14 @@ class UNet2DConditionLoadersMixin:
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
         network_alphas = kwargs.pop("network_alphas", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         allow_pickle = False

+        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
@@ -200,6 +210,10 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
@@ -209,6 +223,7 @@ class UNet2DConditionLoadersMixin:
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )
         else:
             raise ValueError(
@@ -268,7 +283,9 @@ class UNet2DConditionLoadersMixin:

         return attn_processors

-    def _process_lora(
+    def _process_lora(
+        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+    ):
         # This method does the following things:
         # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
         # format. For legacy format no filtering is applied.
@@ -335,18 +352,37 @@ class UNet2DConditionLoadersMixin:
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
         is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+        peft_kwargs = {}
+        if is_peft_version(">=", "0.13.1"):
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

-        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+        inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+        incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

+        warn_msg = ""
         if incompatible_keys is not None:
-            #
+            # Check only for unexpected keys.
             unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
             if unexpected_keys:
-
-
-
-
+                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+                if lora_unexpected_keys:
+                    warn_msg = (
+                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                        f" {', '.join(lora_unexpected_keys)}. "
+                    )
+
+            # Filter missing keys specific to the current adapter.
+            missing_keys = getattr(incompatible_keys, "missing_keys", None)
+            if missing_keys:
+                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+                if lora_missing_keys:
+                    warn_msg += (
+                        f"Loading adapter weights from state_dict led to missing keys in the model:"
+                        f" {', '.join(lora_missing_keys)}."
+                    )
+
+        if warn_msg:
+            logger.warning(warn_msg)

         return is_model_cpu_offload, is_sequential_cpu_offload

@@ -456,6 +492,9 @@ class UNet2DConditionLoadersMixin:
             )
             state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
         else:
+            deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
+            deprecate("save_attn_procs", "0.40.0", deprecation_message)
+
             if not USE_PEFT_BACKEND:
                 raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

@@ -734,6 +773,7 @@ class UNet2DConditionLoadersMixin:
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            IPAdapterXFormersAttnProcessor,
         )

         if low_cpu_mem_usage:
@@ -773,11 +813,15 @@ class UNet2DConditionLoadersMixin:
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
-
             else:
-
-
-
+                if "XFormers" in str(self.attn_processors[name].__class__):
+                    attn_processor_class = IPAdapterXFormersAttnProcessor
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:
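In UNet2DConditionLoadersMixin, `load_attn_procs()` now accepts a `low_cpu_mem_usage` kwarg (rejected when peft <= 0.13.0 is installed and forwarded to peft only from 0.13.1 on), and both `load_attn_procs()` and `save_attn_procs()` emit deprecation warnings pointing at `load_lora_adapter()` / `save_lora_adapter()`. A hedged sketch of the still-working deprecated path; the model ID, LoRA directory, and weight file name are placeholders:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("<some/sd-checkpoint>")

# Deprecated but functional until 0.40.0; emits a warning suggesting load_lora_adapter().
# low_cpu_mem_usage=True requires peft > 0.13.0, otherwise a ValueError is raised.
pipe.unet.load_attn_procs(
    "<path/to/lora>", weight_name="pytorch_lora_weights.safetensors", low_cpu_mem_usage=True
)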
diffusers/models/__init__.py
CHANGED
@@ -27,18 +27,29 @@ _import_structure = {}
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+    _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+    _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+    _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
-    _import_structure["controlnet"] = ["ControlNetModel"]
-    _import_structure["
-    _import_structure["
-
-
+    _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+    _import_structure["controlnets.controlnet_hunyuan"] = [
+        "HunyuanDiT2DControlNetModel",
+        "HunyuanDiT2DMultiControlNetModel",
+    ]
+    _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+    _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+    _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
+    _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -50,10 +61,16 @@ if is_torch_available():
     _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+    _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+    _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -68,7 +85,7 @@ if is_torch_available():
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]

 if is_flax_available():
-    _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+    _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
     _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
     _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

@@ -78,32 +95,52 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoders import (
         AsymmetricAutoencoderKL,
+        AutoencoderDC,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
+        AutoencoderKLHunyuanVideo,
+        AutoencoderKLLTXVideo,
+        AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         VQModel,
     )
-    from .
-
-
-
-
+    from .controlnets import (
+        ControlNetModel,
+        ControlNetUnionModel,
+        ControlNetXSAdapter,
+        FluxControlNetModel,
+        FluxMultiControlNetModel,
+        HunyuanDiT2DControlNetModel,
+        HunyuanDiT2DMultiControlNetModel,
+        MultiControlNetModel,
+        SD3ControlNetModel,
+        SD3MultiControlNetModel,
+        SparseControlNetModel,
+        UNetControlNetXSModel,
+    )
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
+        CogView3PlusTransformer2DModel,
         DiTTransformer2DModel,
         DualTransformer2DModel,
         FluxTransformer2DModel,
         HunyuanDiT2DModel,
+        HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
+        LTXVideoTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        SanaTransformer2DModel,
         SD3Transformer2DModel,
         StableAudioDiTModel,
         T5FilmDecoder,
@@ -125,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     )

     if is_flax_available():
-        from .
+        from .controlnets import FlaxControlNetModel
         from .unets import FlaxUNet2DConditionModel
         from .vae_flax import FlaxAutoencoderKL

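With this change the ControlNet implementations move under `diffusers.models.controlnets`, while the lazy `_import_structure` map keeps the public class names importable from `diffusers.models` as before. A quick import sketch reflecting the new layout:

# public names still resolve through the lazy import map in diffusers.models
from diffusers.models import ControlNetModel, SparseControlNetModel

# new canonical subpackage introduced by this release
from diffusers.models.controlnets import FluxControlNetModel, SD3ControlNetModel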
diffusers/models/activations.py
CHANGED
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type
-
-
-        return F.gelu(gate
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type
-
-
-        return F.gelu(gate
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):

     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
+
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
         self.activation = nn.SiLU()

@@ -163,3 +164,15 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class LinearActivation(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+        super().__init__()
+
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.activation = get_activation(activation)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        return self.activation(hidden_states)