x-transformers 2.6.1__py3-none-any.whl → 2.6.2__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- x_transformers/x_transformers.py +11 -4
- {x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/METADATA +1 -1
- {x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/RECORD +5 -5
- {x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2435,6 +2435,7 @@ class AttentionLayers(Module):
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
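The hunk above adds a single keyword argument, route_additional_kv_to_top = True, next to the existing self_attn_additional_kv and additional_kv_mask parameters. The surrounding context (per-layer iterators, layers_execute_order handling) indicates this is the signature of AttentionLayers.forward; below is a minimal sketch under that assumption, which only inspects the signature and so makes no claims about tensor shapes:

# Hypothetical check, not part of the release: confirm the new flag and its default on 2.6.2.
import inspect
from x_transformers.x_transformers import AttentionLayers

params = inspect.signature(AttentionLayers.forward).parameters

print('route_additional_kv_to_top' in params)           # True on 2.6.2, absent on 2.6.1
print(params['route_additional_kv_to_top'].default)     # True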
@@ -2544,10 +2545,6 @@ class AttentionLayers(Module):
 
         iter_attn_cache = iter(attn_cache)
 
-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed
 
         deep_embeds = []
@@ -2582,6 +2579,16 @@
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed
 
         inp_inject = None
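Together with the previous hunk, this is a move plus an extension: the iter_self_attn_kv construction removed above is rebuilt here, after an optional routing step. When route_additional_kv_to_top is set and self_attn_additional_kv is given, the supplied (key, value) pairs are trimmed to at most the number of self-attention layers (layer type 'a') and then left-padded with None, so the last pairs line up with the topmost self-attention layers, for the case where an external model such as a VLM supplies key / values for only the upper layers. A standalone sketch of that alignment, using a hypothetical helper name and strings in place of tensor pairs:

# route_kv_to_top is illustrative only: it is not a function in x-transformers,
# and the strings stand in for (k, v) tensor tuples.
def route_kv_to_top(additional_kv, num_self_attns):
    # keep at most the last num_self_attns pairs ...
    additional_kv = additional_kv[-num_self_attns:]
    # ... then left-pad with None so the remaining pairs feed the topmost self-attention layers
    return [None] * (num_self_attns - len(additional_kv)) + additional_kv

print(route_kv_to_top(['kv1', 'kv2'], num_self_attns = 4))          # [None, None, 'kv1', 'kv2']
print(route_kv_to_top(['kv1', 'kv2', 'kv3'], num_self_attns = 2))   # ['kv2', 'kv3']

The per-layer iterator is then built from this padded list, so lower self-attention layers presumably draw None (no extra key / values) while the top layers receive the supplied pairs.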
{x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/RECORD
CHANGED
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=kcdU6gp4QXfab9P0M5WYnVic6nTFrTtRGyLEcPFBQcY,121719
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
+x_transformers-2.6.2.dist-info/METADATA,sha256=zHbOLX82fzv9Lzg6i2ZqkJhGJnDtAvwNIl9fbnKnhl8,90223
+x_transformers-2.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.2.dist-info/RECORD,,
{x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/WHEEL
File without changes
{x_transformers-2.6.1.dist-info → x_transformers-2.6.2.dist-info}/licenses/LICENSE
File without changes