x-transformers 2.5.6.tar.gz → 2.6.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.5.6 → x_transformers-2.6.1}/PKG-INFO +1 -1
- {x_transformers-2.5.6 → x_transformers-2.6.1}/pyproject.toml +1 -1
- {x_transformers-2.5.6 → x_transformers-2.6.1}/tests/test_x_transformers.py +25 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/x_transformers.py +30 -2
- {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/FUNDING.yml +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/.gitignore +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/LICENSE +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/README.md +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/data/README.md +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/data/enwik8.gz +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/all-attention.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/attention-on-attention.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/deepnorm.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/fcm.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/ffglu.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/flash-attention.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/gate_values.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/gating.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/macaron-1.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/macaron-2.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/memory-transformer.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/normformer.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/pia.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/resi_dual.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/residual_attn.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/rezero.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/rotary.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich-2.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/sandwich_norm.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/scalenorm.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/talking-heads.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/topk-attention.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/images/xval.png +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_belief_state.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_copy.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_enwik8.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_length_extrapolate.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/train_parity.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/__init__.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/attend.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/continuous.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/dpo.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.5.6 → x_transformers-2.6.1}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

```diff
@@ -1210,3 +1210,28 @@ def test_prompts_given_as_list_tensor():
     ], 256)
 
     assert sampled.shape == (4, 256)
+
+def test_external_key_values():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            attn_dim_head = 16
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    key_values = [
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+    ]
+
+    additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
+
+    logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
```
x_transformers/x_transformers.py

```diff
@@ -1617,7 +1617,9 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1787,6 +1789,26 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
 
+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+            seq_len = k.shape[-2]
+
+            added_k, added_v = additional_key_values
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)
+
         # determine masking
 
         mask_value = max_neg_value(q)
```
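The block added above merges the key-padding mask with an optional mask over the injected key / values: the added keys are prepended, so their mask entries come first, and whichever mask is absent is padded with `True` (attendable). Below is a minimal standalone sketch of the three cases, using plain `torch.nn.functional.pad` in place of the library's `pad_at_dim` helper; the batch size and lengths are illustrative.

```python
import torch
import torch.nn.functional as F

batch, seq_len, added_kv_len = 3, 1024, 32

input_mask = torch.ones(batch, seq_len).bool()                          # mask over the original keys
additional_kv_mask = torch.randint(0, 2, (batch, added_kv_len)).bool()  # mask over the appended keys

# case 1: only input_mask exists -> added keys are always attendable (pad True on the left)
merged = F.pad(input_mask, (added_kv_len, 0), value = True)

# case 2: only additional_kv_mask exists -> original keys are always attendable (pad True on the right)
merged = F.pad(additional_kv_mask, (0, seq_len), value = True)

# case 3: both exist -> concatenate, added keys first, matching cat((added_k, k), dim = -2)
merged = torch.cat((additional_kv_mask, input_mask), dim = -1)

assert merged.shape == (batch, added_kv_len + seq_len)
```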
x_transformers/x_transformers.py (continued)

```diff
@@ -2411,6 +2433,8 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
@@ -2520,6 +2544,10 @@ class AttentionLayers(Module):
 
         iter_attn_cache = iter(attn_cache)
 
+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed
 
         deep_embeds = []
@@ -2647,7 +2675,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
```
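Taken together, these hunks let a caller pass pre-computed key / values into each self-attention layer via `self_attn_additional_kv`, with an optional `additional_kv_mask` over the injected positions. The sketch below mirrors the new test added in this release; the model configuration and the `(batch, heads, added_kv_len, dim_head)` shapes of the injected tensors are illustrative, not required values.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,        # one (key, value) pair is consumed per self-attention layer
        heads = 8,
        attn_dim_head = 16
    )
)

seq = torch.randint(0, 20000, (3, 1024))

# external key / values, shaped (batch, heads, added_kv_len, dim_head)
key_values = [
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
]

# boolean mask over the 32 injected positions, shared across layers
additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
```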