x-transformers 2.6.1__py3-none-any.whl → 2.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2435,6 +2435,7 @@ class AttentionLayers(Module):
2435
2435
  deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
2436
2436
  self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
2437
2437
  additional_kv_mask = None,
2438
+ route_additional_kv_to_top = True,
2438
2439
  condition = None,
2439
2440
  in_attn_cond = None, # https://arxiv.org/abs/2105.04090
2440
2441
  layers_execute_order: tuple[int, ...] | None = None
@@ -2544,10 +2545,6 @@ class AttentionLayers(Module):
2544
2545
 
2545
2546
  iter_attn_cache = iter(attn_cache)
2546
2547
 
2547
- # additional self attn key / values
2548
-
2549
- iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
2550
-
2551
2548
  # handle deep embeds if needed
2552
2549
 
2553
2550
  deep_embeds = []
@@ -2582,6 +2579,16 @@ class AttentionLayers(Module):
2582
2579
  layers_execute_order = default(layers_execute_order, self.layers_execute_order)
2583
2580
  layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
2584
2581
 
2582
+ # additional self attn key / values - say coming from vlm
2583
+
2584
+ if exists(self_attn_additional_kv) and route_additional_kv_to_top:
2585
+ num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
2586
+
2587
+ self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
2588
+ self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
2589
+
2590
+ iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
2591
+
2585
2592
  # derived input for reinjection if needed
2586
2593
 
2587
2594
  inp_inject = None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.6.1
3
+ Version: 2.6.2
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
11
  x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
12
- x_transformers/x_transformers.py,sha256=O8Z4j7wDrj47-lZxmpvToHbHpoFqLy2pk199tQ4v4hI,121281
12
+ x_transformers/x_transformers.py,sha256=kcdU6gp4QXfab9P0M5WYnVic6nTFrTtRGyLEcPFBQcY,121719
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.6.1.dist-info/METADATA,sha256=VVzitcHytmh6tNmtjSMyDWjFxfjpQu6PhR4sTFkxjpk,90223
16
- x_transformers-2.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.6.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.6.1.dist-info/RECORD,,
15
+ x_transformers-2.6.2.dist-info/METADATA,sha256=zHbOLX82fzv9Lzg6i2ZqkJhGJnDtAvwNIl9fbnKnhl8,90223
16
+ x_transformers-2.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.6.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.6.2.dist-info/RECORD,,