x-transformers 2.6.0__py3-none-any.whl → 2.6.2__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to one of the supported public registries, and is provided for informational purposes only.
- x_transformers/x_transformers.py +25 -9
- {x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/METADATA +1 -1
- {x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/RECORD +5 -5
- {x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/WHEEL +0 -0
- {x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1618,7 +1618,8 @@ class Attention(Module):
         return_intermediates = False,
         cache: Intermediates | None = None,
         value_residual = None,
-        additional_key_values: tuple[Tensor, Tensor] | None = None
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1791,15 +1792,22 @@ class Attention(Module):
         # maybe append additional key / values

         if exists(additional_key_values):
+            seq_len = k.shape[-2]

             added_k, added_v = additional_key_values
-            added_kv_len = added_k.shape[-2]

             k = cat((added_k, k), dim = -2)
             v = cat((added_v, v), dim = -2)

-            if exists(input_mask):
-                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)

         # determine masking

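The rewritten mask handling above covers three cases: only the regular `input_mask` is given, only a mask for the appended key/values is given, or both are given. The snippet below is a standalone rehearsal of that combination logic, not the library's code path; `pad_at_dim` here is a minimal stand-in for the utility of the same name in x_transformers, and `combine_masks` plus all shapes are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = False):
    # minimal stand-in for x-transformers' pad_at_dim utility:
    # pads dimension `dim` by (left, right) amounts in `pad` with `value`
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

def combine_masks(input_mask, additional_kv_mask, seq_len, added_kv_len):
    # illustrative re-implementation of the three branches added in 2.6.2
    if input_mask is None and additional_kv_mask is None:
        return None
    if additional_kv_mask is None:
        # only the regular sequence is masked; added kv positions default to attendable
        return pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
    if input_mask is None:
        # only the added kv positions are masked; the sequence itself defaults to attendable
        return pad_at_dim(additional_kv_mask, (0, seq_len), dim = -1, value = True)
    # both given: the added kv mask sits in front, matching cat((added_k, k), dim = -2)
    return torch.cat((additional_kv_mask, input_mask), dim = -1)

# toy shapes (assumptions): batch 2, sequence length 5, 3 appended key/value positions
mask = torch.ones(2, 5, dtype = torch.bool)
added_mask = torch.tensor([[True, True, False]] * 2)
print(combine_masks(mask, added_mask, seq_len = 5, added_kv_len = 3).shape)  # torch.Size([2, 8])
```

Padding with `True` on the left matches the concatenation order `cat((added_k, k), dim = -2)`: the appended key/value positions sit in front of the sequence, and an omitted mask on either side means "fully attendable".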
@@ -2426,6 +2434,8 @@ class AttentionLayers(Module):
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
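Per the annotations above, `self_attn_additional_kv` is a list of `(key, value)` tensor pairs (at most one per self-attention layer), and `additional_kv_mask` is a single boolean mask that, as the last hunk below shows, is forwarded unchanged to every self-attention block. A rough caller-side sketch of assembling these inputs follows; the tensor layout `(batch, kv_heads, added_len, dim_head)` is an assumption for illustration, not something this diff specifies.

```python
import torch

# hypothetical dimensions - the real layout depends on the Attention configuration
batch, kv_heads, dim_head = 2, 8, 64
num_added = 16        # e.g. tokens coming from a vision tower in a VLM setup
num_layers_fed = 4    # how many self-attention layers should receive the extra kv

# one (key, value) pair per layer that should see the additional context
self_attn_additional_kv = [
    (
        torch.randn(batch, kv_heads, num_added, dim_head),  # added keys   (assumed layout)
        torch.randn(batch, kv_heads, num_added, dim_head),  # added values (assumed layout)
    )
    for _ in range(num_layers_fed)
]

# one shared boolean mask over the added positions, True = attendable
additional_kv_mask = torch.ones(batch, num_added, dtype = torch.bool)
```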
@@ -2535,10 +2545,6 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed

         deep_embeds = []
@@ -2573,6 +2579,16 @@ class AttentionLayers(Module):
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed

         inp_inject = None
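With `route_additional_kv_to_top = True`, a list shorter than the number of self-attention layers is aligned to the deepest layers by left-padding it with `None`, so each `next(iter_self_attn_kv, None)` call in the layer loop hands the extra key/values only to the top of the stack. A tiny standalone rehearsal of that alignment (hypothetical `route_to_top` helper, placeholder strings instead of tensors):

```python
def route_to_top(additional_kv, num_self_attns):
    # keep at most one entry per self-attention layer, preferring the tail of the supplied list
    additional_kv = additional_kv[-num_self_attns:]
    # left-pad with None so the supplied pairs line up with the *last* (deepest) layers
    return [None] * (num_self_attns - len(additional_kv)) + additional_kv

# 6 self-attention layers, but only 4 kv pairs supplied (placeholders here)
print(route_to_top(['kv0', 'kv1', 'kv2', 'kv3'], num_self_attns = 6))
# [None, None, 'kv0', 'kv1', 'kv2', 'kv3'] -> only the top 4 layers receive extra kv
```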
@@ -2666,7 +2682,7 @@ class AttentionLayers(Module):
             # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
{x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/RECORD
CHANGED
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=kcdU6gp4QXfab9P0M5WYnVic6nTFrTtRGyLEcPFBQcY,121719
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
-x_transformers-2.6.
+x_transformers-2.6.2.dist-info/METADATA,sha256=zHbOLX82fzv9Lzg6i2ZqkJhGJnDtAvwNIl9fbnKnhl8,90223
+x_transformers-2.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.6.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.6.2.dist-info/RECORD,,
{x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/WHEEL
File without changes
{x_transformers-2.6.0.dist-info → x_transformers-2.6.2.dist-info}/licenses/LICENSE
File without changes