diffusers 0.30.0__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +1 -1
- diffusers/loaders/lora_pipeline.py +37 -7
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_utils.py +34 -9
- diffusers/models/attention_processor.py +142 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +327 -91
- diffusers/models/embeddings.py +84 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +17 -1
- diffusers/models/transformers/cogvideox_transformer_3d.py +196 -56
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +78 -19
- diffusers/utils/export_utils.py +50 -3
- diffusers/utils/import_utils.py +19 -0
- diffusers/utils/loading_utils.py +16 -12
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/METADATA +1 -1
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/RECORD +19 -19
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.0.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora_pipeline.py
CHANGED
@@ -1489,10 +1489,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        return_alphas: bool = False,
         **kwargs,
     ):
         r"""
@@ -1577,7 +1577,26 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             allow_pickle=allow_pickle,
         )
 
-
+        # For state dicts like
+        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
+        keys = list(state_dict.keys())
+        network_alphas = {}
+        for k in keys:
+            if "alpha" in k:
+                alpha_value = state_dict.get(k)
+                if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
+                    alpha_value, float
+                ):
+                    network_alphas[k] = state_dict.pop(k)
+                else:
+                    raise ValueError(
+                        f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
+                    )
+
+        if return_alphas:
+            return state_dict, network_alphas
+        else:
+            return state_dict
 
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1611,7 +1630,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
+        )
 
         is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1619,6 +1640,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
 
         self.load_lora_into_transformer(
             state_dict,
+            network_alphas=network_alphas,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
             _pipeline=self,
@@ -1628,7 +1650,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
                 text_encoder_state_dict,
-                network_alphas=
+                network_alphas=network_alphas,
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
@@ -1637,8 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             )
 
     @classmethod
-
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -1647,6 +1668,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            transformer (`SD3Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
@@ -1678,7 +1703,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             if "lora_B" in key:
                 rank[key] = val.shape[1]
 
-
+        if network_alphas is not None and len(network_alphas) >= 1:
+            prefix = cls.transformer_name
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
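Taken together, the lora_pipeline.py changes let FluxLoraLoaderMixin carry Kohya-style "alpha" scalars (as in the TheLastBen/Jon_Snow_Flux_LoRA checkpoint referenced above) through to the PEFT config, where such keys previously could trip the `lora`-key format check. A minimal standalone sketch of the same splitting step, using a toy state dict with made-up key names rather than a real LoRA file:

import torch

def split_alphas(state_dict):
    # Mirror of the new lora_state_dict(..., return_alphas=True) behaviour:
    # any key containing "alpha" is moved out of the LoRA weights and into a
    # separate network_alphas dict, provided it holds a float-like value.
    network_alphas = {}
    for k in list(state_dict.keys()):
        if "alpha" in k:
            alpha_value = state_dict.get(k)
            if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
                alpha_value, float
            ):
                network_alphas[k] = state_dict.pop(k)
            else:
                raise ValueError(f"The alpha key ({k}) seems to be incorrect.")
    return state_dict, network_alphas

# Toy example: one LoRA pair plus its alpha scalar (key names are illustrative only).
toy = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 8),
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(8, 4),
    "transformer.single_transformer_blocks.0.attn.to_q.alpha": torch.tensor(8.0),
}
weights, alphas = split_alphas(toy)
print(list(alphas))  # the alpha key, now separated from the LoRA weights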
diffusers/loaders/single_file.py
CHANGED
@@ -23,6 +23,7 @@ from packaging import version
 from ..utils import deprecate, is_transformers_available, logging
 from .single_file_utils import (
     SingleFileComponentError,
+    _is_legacy_scheduler_kwargs,
     _is_model_weights_in_cached_folder,
     _legacy_load_clip_tokenizer,
     _legacy_load_safety_checker,
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
 # Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
 SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
 
-
 if is_transformers_available():
     import transformers
     from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
         )
 
-    elif is_diffusers_scheduler and is_legacy_loading:
+    elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
        loaded_sub_model = _legacy_load_scheduler(
            class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
        )
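The single_file.py change widens when the legacy scheduler loader is used: previously only `is_legacy_loading` triggered it, now passing legacy scheduler kwargs does as well. A rough sketch of the combined condition, with stand-in booleans instead of the real loader internals (`_is_legacy_scheduler_kwargs` and `SCHEDULER_LEGACY_KWARGS` are defined in single_file_utils.py below):

# Hedged sketch; names mirror the diff but the surrounding variables are stand-ins.
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]

def _is_legacy_scheduler_kwargs(kwargs):
    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())

def should_use_legacy_scheduler(is_diffusers_scheduler, is_legacy_loading, kwargs):
    # Previously only is_legacy_loading took this branch; now passing e.g.
    # prediction_type="v_prediction" to from_single_file does too.
    return is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs))

print(should_use_legacy_scheduler(True, False, {"prediction_type": "v_prediction"}))  # True
print(should_use_legacy_scheduler(True, False, {}))  # False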
diffusers/loaders/single_file_utils.py
CHANGED
@@ -79,7 +79,10 @@ CHECKPOINT_KEY_NAMES = {
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
-    "flux":
+    "flux": [
+        "double_blocks.0.img_attn.norm.key_norm.scale",
+        "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
+    ],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -258,7 +261,7 @@ SCHEDULER_DEFAULT_CONFIG = {
     "timestep_spacing": "leading",
 }
 
-
+LDM_VAE_KEYS = ["first_stage_model.", "vae."]
 LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 PLAYGROUND_VAE_SCALING_FACTOR = 0.5
 LDM_UNET_KEY = "model.diffusion_model."
@@ -267,8 +270,8 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
 ]
-OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
+SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
 
 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
 
@@ -318,6 +321,10 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
     return weights_exist
 
 
+def _is_legacy_scheduler_kwargs(kwargs):
+    return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
+
+
 def load_single_file_checkpoint(
     pretrained_model_link_or_path,
     force_download=False,
@@ -516,8 +523,10 @@ def infer_diffusers_model_type(checkpoint):
         else:
             model_type = "animatediff_v3"
 
-    elif CHECKPOINT_KEY_NAMES["flux"]
-        if
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
+        if any(
+            g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
+        ):
             model_type = "flux-dev"
         else:
             model_type = "flux-schnell"
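Since Flux single-file checkpoints can now appear with or without the `model.diffusion_model.` prefix, the detection checks both key layouts. A small sketch of the same branching over toy dictionaries that contain only the keys the detection inspects:

# Illustrative only: tiny fake checkpoints rather than real state dicts.
FLUX_KEYS = [
    "double_blocks.0.img_attn.norm.key_norm.scale",
    "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
]
GUIDANCE_KEYS = ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]

def infer_flux_variant(checkpoint):
    if any(key in checkpoint for key in FLUX_KEYS):
        # flux-dev checkpoints carry a guidance embedder; flux-schnell ones do not.
        return "flux-dev" if any(g in checkpoint for g in GUIDANCE_KEYS) else "flux-schnell"
    return None

original_layout = {"double_blocks.0.img_attn.norm.key_norm.scale": 0, "guidance_in.in_layer.bias": 0}
prefixed_layout = {"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale": 0}
print(infer_flux_variant(original_layout))  # flux-dev
print(infer_flux_variant(prefixed_layout))  # flux-schnell (no guidance key in this toy dict)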
@@ -1176,7 +1185,11 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
     vae_state_dict = {}
     keys = list(checkpoint.keys())
-    vae_key =
+    vae_key = ""
+    for ldm_vae_key in LDM_VAE_KEYS:
+        if any(k.startswith(ldm_vae_key) for k in keys):
+            vae_key = ldm_vae_key
+
     for key in keys:
         if key.startswith(vae_key):
             vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
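With LDM_VAE_KEYS, the VAE conversion can strip either the `first_stage_model.` or the `vae.` prefix. A toy sketch of just that prefix selection, separate from the full convert_ldm_vae_checkpoint logic:

LDM_VAE_KEYS = ["first_stage_model.", "vae."]

def strip_vae_prefix(checkpoint):
    keys = list(checkpoint.keys())
    vae_key = ""
    # Pick whichever known prefix the checkpoint actually uses (if any).
    for ldm_vae_key in LDM_VAE_KEYS:
        if any(k.startswith(ldm_vae_key) for k in keys):
            vae_key = ldm_vae_key
    return {k.replace(vae_key, ""): v for k, v in checkpoint.items() if k.startswith(vae_key)}

print(strip_vae_prefix({"vae.encoder.conv_in.weight": 0}))
print(strip_vae_prefix({"first_stage_model.encoder.conv_in.weight": 0}))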
@@ -1477,14 +1490,22 @@ def _legacy_load_scheduler(
 
     if scheduler_type is not None:
         deprecation_message = (
-            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file
+            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            "scheduler = DDIMScheduler()\n"
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("scheduler_type", "1.0.0", deprecation_message)
 
     if prediction_type is not None:
         deprecation_message = (
-            "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
-            "
+            "Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
+            "pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
+            "Example:\n\n"
+            "from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
+            'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
+            "pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
         )
         deprecate("prediction_type", "1.0.0", deprecation_message)
 
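The reworked deprecation messages spell out the migration path: build the scheduler yourself (setting `prediction_type` there if needed) and pass the object to `from_single_file`. This is essentially the example embedded in the new message text; the checkpoint path below is a placeholder:

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Configure the scheduler explicitly, including prediction_type if needed ...
scheduler = DDIMScheduler(prediction_type="v_prediction")

# ... and hand the object to from_single_file instead of the deprecated kwargs.
pipe = StableDiffusionPipeline.from_single_file("path/to/checkpoint.safetensors", scheduler=scheduler)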
@@ -1881,6 +1902,10 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
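The Flux transformer converter now normalizes prefixed checkpoints before converting, by renaming any `model.diffusion_model.*` key to its unprefixed form. A toy sketch of that normalization step, assuming a plain dict stands in for the loaded checkpoint:

def strip_diffusion_model_prefix(checkpoint):
    # Rename prefixed keys in place so the rest of the conversion can assume
    # the un-prefixed layout (e.g. "double_blocks.0....").
    for k in list(checkpoint.keys()):
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
    return checkpoint

ckpt = {"model.diffusion_model.double_blocks.0.img_attn.qkv.weight": 0}
print(strip_diffusion_model_prefix(ckpt))  # {'double_blocks.0.img_attn.qkv.weight': 0}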
diffusers/models/attention_processor.py
CHANGED
@@ -1868,6 +1868,148 @@ class FluxAttnProcessor2_0:
         return hidden_states, encoder_hidden_states
 
 
+class CogVideoXAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.