x-transformers 2.6.1__py3-none-any.whl → 2.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1795,6 +1795,13 @@ class Attention(Module):
     seq_len = k.shape[-2]

     added_k, added_v = additional_key_values
+    added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
+
+    # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
+
+    if added_kv_heads != kv_h:
+        assert divisible_by(h, added_kv_heads)
+        k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))

     k = cat((added_k, k), dim = -2)
     v = cat((added_v, v), dim = -2)
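
The added lines handle a head-count mismatch: when the injected key / values (for example from a VLM) carry fewer heads than the attention module, everything is repeated up to the query head count before concatenation. A minimal sketch of that expansion with einops.repeat, using hypothetical shapes (batch, kv_heads, kv_len, dim_head):

# sketch of the head expansion above, with hypothetical shapes
import torch
from einops import repeat

h = 8                                  # query heads on the attention module
added_k = torch.randn(2, 2, 32, 64)    # (batch, kv_heads = 2, kv_len, dim_head), e.g. from a vlm
added_kv_heads = added_k.shape[1]

assert h % added_kv_heads == 0         # same condition as divisible_by(h, added_kv_heads)

# repeat each kv head (h // kv_heads) times so it lines up with the query heads
added_k = repeat(added_k, 'b h ... -> b (r h) ...', r = h // added_kv_heads)

print(added_k.shape)                   # torch.Size([2, 8, 32, 64])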
@@ -1802,7 +1809,6 @@ class Attention(Module):
     if (exists(input_mask) or exists(additional_key_value_mask)):

         if not exists(additional_key_value_mask):
-            added_kv_len = added_k.shape[-2]
             input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
         elif not exists(input_mask):
             input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
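
The deleted line is only a de-duplication: added_kv_len is now computed once in the earlier hunk. The masking itself is unchanged; it left-pads the token mask with True so the prepended key / value positions stay attendable. An illustrative sketch with plain torch.nn.functional.pad (the library uses its own pad_at_dim helper):

# illustration only: equivalent left-padding of a boolean mask with torch.nn.functional.pad
import torch
import torch.nn.functional as F

input_mask = torch.ones(2, 10, dtype = torch.bool)   # (batch, seq_len) mask for the regular tokens
added_kv_len = 4                                      # length of the key / values concatenated in front

# pad True on the left so the mask also covers the added kv positions
input_mask = F.pad(input_mask, (added_kv_len, 0), value = True)

print(input_mask.shape)   # torch.Size([2, 14])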
@@ -2435,6 +2441,7 @@ class AttentionLayers(Module):
     deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
     self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
     additional_kv_mask = None,
+    route_additional_kv_to_top = True,
     condition = None,
     in_attn_cond = None, # https://arxiv.org/abs/2105.04090
     layers_execute_order: tuple[int, ...] | None = None
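
A hedged usage sketch of the new keyword argument, assuming a Decoder (an AttentionLayers subclass) and injected key / values shaped (batch, kv_heads, kv_len, dim_head); the diff only fixes dim 1 as heads and dim -2 as length, so the exact shapes here are an assumption:

# sketch under assumptions: Decoder inherits AttentionLayers.forward, and the injected
# k / v tensors are (batch, kv_heads, kv_len, dim_head)
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 512, depth = 6, heads = 8)

x = torch.randn(1, 1024, 512)

# fewer (key, value) pairs than depth, e.g. produced by a vlm
vlm_kv = [
    (torch.randn(1, 2, 32, 64), torch.randn(1, 2, 32, 64)),
    (torch.randn(1, 2, 32, 64), torch.randn(1, 2, 32, 64)),
]

out = attn_layers(
    x,
    self_attn_additional_kv = vlm_kv,
    route_additional_kv_to_top = True   # default: pairs are applied to the topmost self-attention layers
)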
@@ -2544,10 +2551,6 @@ class AttentionLayers(Module):

     iter_attn_cache = iter(attn_cache)

-    # additional self attn key / values
-
-    iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
     # handle deep embeds if needed

     deep_embeds = []
@@ -2582,6 +2585,16 @@ class AttentionLayers(Module):
     layers_execute_order = default(layers_execute_order, self.layers_execute_order)
     layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+    # additional self attn key / values - say coming from vlm
+
+    if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+        num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+        self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+        self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+    iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
     # derived input for reinjection if needed

     inp_inject = None
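
The routing itself is plain list alignment: keep at most one (k, v) pair per self-attention layer, then left-pad with None so the supplied pairs land on the topmost layers while the iterator is still consumed bottom to top. A pure-Python sketch with placeholder values instead of tensors:

# pure-python sketch of the alignment in the hunk above (placeholder strings instead of tensors)
num_self_attns = 6                 # number of 'a' entries among the layer types
additional_kv = ['kv1', 'kv2']     # fewer pairs than self-attention layers

additional_kv = additional_kv[-num_self_attns:]
additional_kv = [None] * (num_self_attns - len(additional_kv)) + additional_kv

print(additional_kv)   # [None, None, None, None, 'kv1', 'kv2']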
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.6.1
+Version: 2.6.3
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=O8Z4j7wDrj47-lZxmpvToHbHpoFqLy2pk199tQ4v4hI,121281
+x_transformers/x_transformers.py,sha256=B7dv_LuzODwCrTsfDnp28g-_lMnirQE3gteQwSGyW5k,122100
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.1.dist-info/METADATA,sha256=VVzitcHytmh6tNmtjSMyDWjFxfjpQu6PhR4sTFkxjpk,90223
-x_transformers-2.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.6.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.6.1.dist-info/RECORD,,
+x_transformers-2.6.3.dist-info/METADATA,sha256=DaTrEChlXc_zpUXv-Jw3A4ca4aon0Ons7wl4-wj1XzY,90223
+x_transformers-2.6.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.3.dist-info/RECORD,,