x-transformers 2.6.2__py3-none-any.whl → 2.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1795,6 +1795,13 @@ class Attention(Module):
1795
1795
  seq_len = k.shape[-2]
1796
1796
 
1797
1797
  added_k, added_v = additional_key_values
1798
+ added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
1799
+
1800
+ # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
1801
+
1802
+ if added_kv_heads != kv_h:
1803
+ assert divisible_by(h, added_kv_heads)
1804
+ k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))
1798
1805
 
1799
1806
  k = cat((added_k, k), dim = -2)
1800
1807
  v = cat((added_v, v), dim = -2)
@@ -1802,7 +1809,6 @@ class Attention(Module):
1802
1809
  if (exists(input_mask) or exists(additional_key_value_mask)):
1803
1810
 
1804
1811
  if not exists(additional_key_value_mask):
1805
- added_kv_len = added_k.shape[-2]
1806
1812
  input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
1807
1813
  elif not exists(input_mask):
1808
1814
  input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.6.2
3
+ Version: 2.6.3
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
11
  x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
12
- x_transformers/x_transformers.py,sha256=kcdU6gp4QXfab9P0M5WYnVic6nTFrTtRGyLEcPFBQcY,121719
12
+ x_transformers/x_transformers.py,sha256=B7dv_LuzODwCrTsfDnp28g-_lMnirQE3gteQwSGyW5k,122100
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.6.2.dist-info/METADATA,sha256=zHbOLX82fzv9Lzg6i2ZqkJhGJnDtAvwNIl9fbnKnhl8,90223
16
- x_transformers-2.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.6.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.6.2.dist-info/RECORD,,
15
+ x_transformers-2.6.3.dist-info/METADATA,sha256=DaTrEChlXc_zpUXv-Jw3A4ca4aon0Ons7wl4-wj1XzY,90223
16
+ x_transformers-2.6.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.6.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.6.3.dist-info/RECORD,,