x-transformers 2.6.1__tar.gz → 2.6.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.6.1 → x_transformers-2.6.3}/PKG-INFO +1 -1
- {x_transformers-2.6.1 → x_transformers-2.6.3}/pyproject.toml +1 -1
- {x_transformers-2.6.1 → x_transformers-2.6.3}/tests/test_x_transformers.py +3 -3
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/x_transformers.py +18 -5
- {x_transformers-2.6.1 → x_transformers-2.6.3}/.github/FUNDING.yml +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/.gitignore +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/LICENSE +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/README.md +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/data/README.md +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/data/enwik8.gz +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/all-attention.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/attention-on-attention.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/deepnorm.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/fcm.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/ffglu.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/flash-attention.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/gate_values.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/gating.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/macaron-1.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/macaron-2.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/memory-transformer.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/normformer.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/pia.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/resi_dual.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/residual_attn.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/rezero.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/rotary.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/sandwich-2.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/sandwich.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/sandwich_norm.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/scalenorm.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/talking-heads.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/topk-attention.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/images/xval.png +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_belief_state.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_copy.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_enwik8.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_length_extrapolate.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/train_parity.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/__init__.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/attend.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/continuous.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/dpo.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.6.1 → x_transformers-2.6.3}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -1219,7 +1219,7 @@ def test_external_key_values():
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 512,
-            depth =
+            depth = 3,
             heads = 8,
             attn_dim_head = 16
         )
@@ -1228,8 +1228,8 @@ def test_external_key_values():
     seq = torch.randint(0, 20000, (3, 1024))

     key_values = [
-        (torch.randn(3,
-        (torch.randn(3,
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
+        (torch.randn(3, 2, 32, 16), torch.randn(3, 2, 32, 16)),
     ]

     additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
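For orientation, the shapes in the updated test line up with the new head-expansion path below: the decoder runs heads = 8 with attn_dim_head = 16, while each injected (key, value) pair has shape (batch, kv_heads, kv_len, dim_head) = (3, 2, 32, 16), i.e. only 2 key/value heads, and additional_kv_mask covers the 32 injected positions. A minimal sketch that only reproduces those assumed shapes, not taken verbatim from the test:

import torch

batch, kv_heads, kv_len, dim_head = 3, 2, 32, 16

# one (key, value) pair per self-attention layer that should receive external memories
key_values = [
    (torch.randn(batch, kv_heads, kv_len, dim_head), torch.randn(batch, kv_heads, kv_len, dim_head)),
    (torch.randn(batch, kv_heads, kv_len, dim_head), torch.randn(batch, kv_heads, kv_len, dim_head)),
]

# boolean mask over the injected key / value positions
additional_kv_mask = torch.randint(0, 2, (batch, kv_len)).bool()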
x_transformers/x_transformers.py

@@ -1795,6 +1795,13 @@ class Attention(Module):
             seq_len = k.shape[-2]

             added_k, added_v = additional_key_values
+            added_kv_heads, added_kv_len = added_k.shape[1], added_k.shape[-2]
+
+            # take care of expanding to query heads if mismatch between key / value heads with the ones coming from vlm
+
+            if added_kv_heads != kv_h:
+                assert divisible_by(h, added_kv_heads)
+                k, v, added_k, added_v = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, v, added_k, added_v))

             k = cat((added_k, k), dim = -2)
             v = cat((added_v, v), dim = -2)
@@ -1802,7 +1809,6 @@ class Attention(Module):
            if (exists(input_mask) or exists(additional_key_value_mask)):

                if not exists(additional_key_value_mask):
-                    added_kv_len = added_k.shape[-2]
                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
                elif not exists(input_mask):
                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
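The added block in Attention expands mismatched key/value heads up to the query head count, GQA-style, with einops repeat before prepending the external keys and values along the sequence axis. A self-contained sketch of that expansion, with shapes assumed to match the test above rather than taken from the library:

import torch
from einops import repeat

h = 8                                 # query heads in the attention layer
k = torch.randn(3, 2, 1024, 16)       # (batch, kv_heads, seq_len, dim_head)
added_k = torch.randn(3, 2, 32, 16)   # externally supplied keys, e.g. from a vlm

# every tensor's head count must evenly divide the query head count
assert all(h % t.shape[1] == 0 for t in (k, added_k))

# repeat heads so both tensors end up with h heads
k, added_k = tuple(repeat(t, 'b h ... -> b (r h) ...', r = h // t.shape[1]) for t in (k, added_k))
assert k.shape[1] == added_k.shape[1] == h

# the external keys are then prepended along the sequence dimension
k = torch.cat((added_k, k), dim = -2)  # (3, 8, 32 + 1024, 16)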
@@ -2435,6 +2441,7 @@ class AttentionLayers(Module):
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         additional_kv_mask = None,
+        route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2544,10 +2551,6 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

-        # additional self attn key / values
-
-        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
-
         # handle deep embeds if needed

         deep_embeds = []
@@ -2582,6 +2585,16 @@ class AttentionLayers(Module):
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        # additional self attn key / values - say coming from vlm
+
+        if exists(self_attn_additional_kv) and route_additional_kv_to_top:
+            num_self_attns = sum([layer_type == 'a' for layer_type in first(layer_variables)])
+
+            self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
+            self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # derived input for reinjection if needed

         inp_inject = None
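The new route_additional_kv_to_top branch right-aligns the supplied key/value pairs against the self-attention layers, so a shorter list lands on the topmost layers and lower layers receive None. A minimal sketch of that alignment with plain Python lists; the layer types and pair count here are illustrative assumptions:

# layer types as AttentionLayers stores them: 'a' = self attention, 'f' = feedforward
layer_types = ('a', 'f', 'a', 'f', 'a', 'f')

# two external (key, value) pairs for a stack with three self-attention layers
self_attn_additional_kv = ['kv_pair_1', 'kv_pair_2']

num_self_attns = sum(layer_type == 'a' for layer_type in layer_types)  # 3

# keep at most one pair per self-attention layer, then left-pad with None
self_attn_additional_kv = self_attn_additional_kv[-num_self_attns:]
self_attn_additional_kv = [None] * (num_self_attns - len(self_attn_additional_kv)) + self_attn_additional_kv

# the bottom self-attention layer gets nothing, the top two receive the external key / values
assert self_attn_additional_kv == [None, 'kv_pair_1', 'kv_pair_2']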