x-transformers 2.6.2__py3-none-any.whl → 2.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +7 -1
- {x_transformers-2.6.2.dist-info → x_transformers-2.6.3.dist-info}/METADATA +1 -1
- {x_transformers-2.6.2.dist-info → x_transformers-2.6.3.dist-info}/RECORD +5 -5
- {x_transformers-2.6.2.dist-info → x_transformers-2.6.3.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.2.dist-info → x_transformers-2.6.3.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1795,6 +1795,13 @@ class Attention(Module):
|
|
1795
1795
|
seq_len = k.shape[-2]
|
1796
1796
|
|
1797
1797
|
added_k, added_v = additional_key_values
|
1798
|
+
added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
|
1799
|
+
|
1800
|
+
# take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
|
1801
|
+
|
1802
|
+
if added_kv_heads != kv_h:
|
1803
|
+
assert divisible_by(h, added_kv_heads)
|
1804
|
+
k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))
|
1798
1805
|
|
1799
1806
|
k = cat((added_k, k), dim = -2)
|
1800
1807
|
v = cat((added_v, v), dim = -2)
|
@@ -1802,7 +1809,6 @@ class Attention(Module):
|
|
1802
1809
|
if (exists(input_mask) or exists(additional_key_value_mask)):
|
1803
1810
|
|
1804
1811
|
if not exists(additional_key_value_mask):
|
1805
|
-
added_kv_len = added_k.shape[-2]
|
1806
1812
|
input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
|
1807
1813
|
elif not exists(input_mask):
|
1808
1814
|
input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
|
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
|
|
9
9
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
10
10
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
11
11
|
x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
|
12
|
-
x_transformers/x_transformers.py,sha256=
|
12
|
+
x_transformers/x_transformers.py,sha256=B7dv_LuzODwCrTsfDnp28g-_lMnirQE3gteQwSGyW5k,122100
|
13
13
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
14
14
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
15
|
-
x_transformers-2.6.
|
16
|
-
x_transformers-2.6.
|
17
|
-
x_transformers-2.6.
|
18
|
-
x_transformers-2.6.
|
15
|
+
x_transformers-2.6.3.dist-info/METADATA,sha256=DaTrEChlXc_zpUXv-Jw3A4ca4aon0Ons7wl4-wj1XzY,90223
|
16
|
+
x_transformers-2.6.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
x_transformers-2.6.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
18
|
+
x_transformers-2.6.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|