diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/transformer_flux.py
ADDED
@@ -0,0 +1,181 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

from ..models.embeddings import (
    ImageProjection,
    MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..utils import (
    is_accelerate_available,
    is_torch_version,
    logging,
)


if is_accelerate_available():
    pass

logger = logging.get_logger(__name__)


class FluxTransformer2DLoadersMixin:
    """
    Load layers into a [`FluxTransformer2DModel`].
    """

    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        updated_state_dict = {}
        image_projection = None
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

        if "proj.weight" in state_dict:
            # IP-Adapter
            num_image_text_embeds = 4
            if state_dict["proj.weight"].shape[0] == 65536:
                num_image_text_embeds = 16
            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
            cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds

            with init_context():
                image_projection = ImageProjection(
                    cross_attention_dim=cross_attention_dim,
                    image_embed_dim=clip_embeddings_dim,
                    num_image_text_embeds=num_image_text_embeds,
                )

            for key, value in state_dict.items():
                diffusers_name = key.replace("proj", "image_embeds")
                updated_state_dict[diffusers_name] = value

        if not low_cpu_mem_usage:
            image_projection.load_state_dict(updated_state_dict, strict=True)
        else:
            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

        return image_projection

    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
        from ..models.attention_processor import (
            FluxIPAdapterJointAttnProcessor2_0,
        )

        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # set ip-adapter cross-attention processors & load state_dict
        attn_procs = {}
        key_id = 0
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for name in self.attn_processors.keys():
            if name.startswith("single_transformer_blocks"):
                attn_processor_class = self.attn_processors[name].__class__
                attn_procs[name] = attn_processor_class()
            else:
                cross_attention_dim = self.config.joint_attention_dim
                hidden_size = self.inner_dim
                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
                num_image_text_embeds = []
                for state_dict in state_dicts:
                    if "proj.weight" in state_dict["image_proj"]:
                        num_image_text_embed = 4
                        if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
                            num_image_text_embed = 16
                        # IP-Adapter
                        num_image_text_embeds += [num_image_text_embed]

                with init_context():
                    attn_procs[name] = attn_processor_class(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=num_image_text_embeds,
                        dtype=self.dtype,
                        device=self.device,
                    )

                value_dict = {}
                for i, state_dict in enumerate(state_dicts):
                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})

                if not low_cpu_mem_usage:
                    attn_procs[name].load_state_dict(value_dict)
                else:
                    device = self.device
                    dtype = self.dtype
                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)

                key_id += 1

        return attn_procs

    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
        if not isinstance(state_dicts, list):
            state_dicts = [state_dicts]

        self.encoder_hid_proj = None

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        image_projection_layers = []
        for state_dict in state_dicts:
            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
            )
            image_projection_layers.append(image_projection_layer)

        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
        self.config.encoder_hid_dim_type = "ip_image_proj"

        self.to(dtype=self.dtype, device=self.device)
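A minimal usage sketch for the mixin above. The checkpoint filename is a placeholder, and the only assumptions beyond the code shown here are that the loaded file exposes the `image_proj`/`ip_adapter` sub-dicts the helpers index into and that FLUX.1-dev serves as an example base model; this is not an official loading recipe.

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Placeholder checkpoint laid out as {"image_proj": {...}, "ip_adapter": {...}},
# the structure _load_ip_adapter_weights() and its helpers expect.
state_dict = torch.load("flux_ip_adapter.bin", map_location="cpu")
transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)
```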
diffusers/loaders/transformer_sd3.py
ADDED
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta


class SD3Transformer2DLoadersMixin:
    """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""

    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
        """Sets IP-Adapter attention processors, image projection, and loads state_dict.

        Args:
            state_dict (`Dict`):
                State dict with keys "ip_adapter", which contains parameters for attention processors, and
                "image_proj", which contains parameters for image projection net.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """
        # IP-Adapter cross attention parameters
        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
        timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]

        # Dict where key is transformer layer index, value is attention processor's state dict
        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
        for key, weights in state_dict["ip_adapter"].items():
            idx, name = key.split(".", maxsplit=1)
            layer_state_dict[int(idx)][name] = weights

        # Create IP-Adapter attention processor
        attn_procs = {}
        for idx, name in enumerate(self.attn_processors.keys()):
            attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
                hidden_size=hidden_size,
                ip_hidden_states_dim=ip_hidden_states_dim,
                head_dim=self.config.attention_head_dim,
                timesteps_emb_dim=timesteps_emb_dim,
            ).to(self.device, dtype=self.dtype)

            if not low_cpu_mem_usage:
                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
            else:
                load_model_dict_into_meta(
                    attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
                )

        self.set_attn_processor(attn_procs)

        # Image projetion parameters
        embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
        output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
        hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
        heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
        num_queries = state_dict["image_proj"]["latents"].shape[1]
        timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]

        # Image projection
        self.image_proj = IPAdapterTimeImageProjection(
            embed_dim=embed_dim,
            output_dim=output_dim,
            hidden_dim=hidden_dim,
            heads=heads,
            num_queries=num_queries,
            timestep_in_dim=timestep_in_dim,
        ).to(device=self.device, dtype=self.dtype)

        if not low_cpu_mem_usage:
            self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
        else:
            load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
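The regrouping loop above splits each `ip_adapter` key into a transformer-layer index and a parameter name. A self-contained illustration of that step, using synthetic tensors rather than a real checkpoint:

```python
import torch

# Synthetic stand-in for state_dict["ip_adapter"]; shapes are arbitrary.
ip_adapter = {
    "0.norm_ip.linear.weight": torch.zeros(1280, 32),
    "0.to_k_ip.weight": torch.zeros(1536, 2432),
    "1.norm_ip.linear.weight": torch.zeros(1280, 32),
}

layer_state_dict = {idx: {} for idx in range(2)}
for key, weights in ip_adapter.items():
    idx, name = key.split(".", maxsplit=1)
    layer_state_dict[int(idx)][name] = weights

print(sorted(layer_state_dict[0]))  # ['norm_ip.linear.weight', 'to_k_ip.weight']
```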
diffusers/loaders/unet.py
CHANGED
@@ -36,6 +36,7 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -209,6 +210,10 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
@@ -487,6 +492,9 @@ class UNet2DConditionLoadersMixin:
             )
             state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
         else:
+            deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
+            deprecate("save_attn_procs", "0.40.0", deprecation_message)
+
             if not USE_PEFT_BACKEND:
                 raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

@@ -765,6 +773,7 @@ class UNet2DConditionLoadersMixin:
         from ..models.attention_processor import (
             IPAdapterAttnProcessor,
             IPAdapterAttnProcessor2_0,
+            IPAdapterXFormersAttnProcessor,
         )

         if low_cpu_mem_usage:
@@ -804,11 +813,15 @@ class UNet2DConditionLoadersMixin:
             if cross_attention_dim is None or "motion_modules" in name:
                 attn_processor_class = self.attn_processors[name].__class__
                 attn_procs[name] = attn_processor_class()
-
             else:
-
-
-
+                if "XFormers" in str(self.attn_processors[name].__class__):
+                    attn_processor_class = IPAdapterXFormersAttnProcessor
+                else:
+                    attn_processor_class = (
+                        IPAdapterAttnProcessor2_0
+                        if hasattr(F, "scaled_dot_product_attention")
+                        else IPAdapterAttnProcessor
+                    )
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:
diffusers/models/__init__.py
CHANGED
@@ -27,19 +27,29 @@ _import_structure = {}
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
     _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+    _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+    _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+    _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
-    _import_structure["controlnet"] = ["ControlNetModel"]
-    _import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
-    _import_structure["controlnet_hunyuan"] = [
-
-
-
+    _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+    _import_structure["controlnets.controlnet_hunyuan"] = [
+        "HunyuanDiT2DControlNetModel",
+        "HunyuanDiT2DMultiControlNetModel",
+    ]
+    _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+    _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+    _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
+    _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -51,11 +61,16 @@ if is_torch_available():
     _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+    _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
     _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+    _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -70,7 +85,7 @@ if is_torch_available():
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]

 if is_flax_available():
-    _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+    _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
     _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
     _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

@@ -80,23 +95,37 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoders import (
         AsymmetricAutoencoderKL,
+        AutoencoderDC,
         AutoencoderKL,
+        AutoencoderKLAllegro,
         AutoencoderKLCogVideoX,
+        AutoencoderKLHunyuanVideo,
+        AutoencoderKLLTXVideo,
+        AutoencoderKLMochi,
         AutoencoderKLTemporalDecoder,
         AutoencoderOobleck,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         VQModel,
     )
-    from .
-
-
-
-
-
+    from .controlnets import (
+        ControlNetModel,
+        ControlNetUnionModel,
+        ControlNetXSAdapter,
+        FluxControlNetModel,
+        FluxMultiControlNetModel,
+        HunyuanDiT2DControlNetModel,
+        HunyuanDiT2DMultiControlNetModel,
+        MultiControlNetModel,
+        SD3ControlNetModel,
+        SD3MultiControlNetModel,
+        SparseControlNetModel,
+        UNetControlNetXSModel,
+    )
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
     from .transformers import (
+        AllegroTransformer3DModel,
         AuraFlowTransformer2DModel,
         CogVideoXTransformer3DModel,
         CogView3PlusTransformer2DModel,
@@ -104,10 +133,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         DualTransformer2DModel,
         FluxTransformer2DModel,
         HunyuanDiT2DModel,
+        HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
+        LTXVideoTransformer3DModel,
         LuminaNextDiT2DModel,
+        MochiTransformer3DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        SanaTransformer2DModel,
         SD3Transformer2DModel,
         StableAudioDiTModel,
         T5FilmDecoder,
@@ -129,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     )

     if is_flax_available():
-        from .
+        from .controlnets import FlaxControlNetModel
         from .unets import FlaxUNet2DConditionModel
         from .vae_flax import FlaxAutoencoderKL

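The practical effect of the lazy-import changes above is that the concrete ControlNet modules now live under `diffusers.models.controlnets`, while the top-level exports stay put. A small check of that expectation (assuming both paths resolve exactly as this `__init__.py` suggests):

```python
from diffusers import ControlNetModel
from diffusers.models.controlnets.controlnet import ControlNetModel as ControlNetModelNewPath

# The top-level name and the new module path should refer to the same class object.
assert ControlNetModel is ControlNetModelNewPath
```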
diffusers/models/activations.py
CHANGED
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type
-
-
-        return F.gelu(gate
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type
-
-
-        return F.gelu(gate
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):

     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
+
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
         self.activation = nn.SiLU()

@@ -163,3 +164,15 @@ class ApproximateGELU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class LinearActivation(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+        super().__init__()
+
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        self.activation = get_activation(activation)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        return self.activation(hidden_states)
diffusers/models/attention.py
CHANGED
@@ -19,7 +19,7 @@ from torch import nn

 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim

     def forward(
-        self,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2

@@ -1222,6 +1229,8 @@ class FeedForward(nn.Module):
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         elif activation_fn == "swiglu":
             act_fn = SwiGLU(dim, inner_dim, bias=bias)
+        elif activation_fn == "linear-silu":
+            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")

         self.net = nn.ModuleList([])
         # project in
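The `"linear-silu"` branch added to `FeedForward` above wires `LinearActivation` into existing transformer blocks; a minimal sketch with arbitrary dimensions:

```python
import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=64, activation_fn="linear-silu")  # inner projection uses LinearActivation
out = ff(torch.randn(1, 16, 64))
print(out.shape)  # torch.Size([1, 16, 64])
```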
diffusers/models/attention_flax.py
CHANGED
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
             hidden_states = jax_memory_efficient_attention(
                 query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
             )
-
             hidden_states = hidden_states.transpose(1, 0, 2)
+            hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         else:
             # compute attentions
             if self.split_head_dim:
|