x-transformers 2.5.5.tar.gz → 2.6.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.5.5 → x_transformers-2.6.0}/PKG-INFO +1 -1
- {x_transformers-2.5.5 → x_transformers-2.6.0}/pyproject.toml +1 -1
- {x_transformers-2.5.5 → x_transformers-2.6.0}/tests/test_x_transformers.py +23 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/x_transformers.py +22 -3
- {x_transformers-2.5.5 → x_transformers-2.6.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/.gitignore +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/LICENSE +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/README.md +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/data/README.md +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/data/enwik8.gz +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/all-attention.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/deepnorm.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/fcm.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/ffglu.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/flash-attention.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/gate_values.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/gating.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/macaron-1.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/macaron-2.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/normformer.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/pia.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/resi_dual.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/residual_attn.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/rezero.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/rotary.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/sandwich.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/scalenorm.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/talking-heads.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/topk-attention.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/images/xval.png +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_belief_state.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_copy.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_enwik8.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/train_parity.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.5.5 → x_transformers-2.6.0}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

```diff
@@ -1210,3 +1210,26 @@ def test_prompts_given_as_list_tensor():
     ], 256)
 
     assert sampled.shape == (4, 256)
+
+def test_external_key_values():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 2,
+            heads = 8,
+            attn_dim_head = 16
+        )
+    )
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    key_values = [
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+        (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
+    ]
+
+    logits = model(seq, self_attn_additional_kv = key_values)
```
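The new test also documents the shape contract: `self_attn_additional_kv` takes one `(key, value)` tuple per self-attention layer, each tensor shaped `(batch, heads, added_seq_len, dim_head)` to match the decoder configuration (`heads = 8`, `attn_dim_head = 16` above). A minimal usage sketch along the same lines; the comments and variable names are illustrative, not part of the package:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# same configuration as the new test
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8, attn_dim_head = 16)
)

seq = torch.randint(0, 20000, (3, 1024))

# one (key, value) tuple per self-attention layer (depth = 2),
# each shaped (batch, heads, added_seq_len, dim_head)
external_kv = [
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16))
    for _ in range(2)
]

logits = model(seq, self_attn_additional_kv = external_kv)  # (3, 1024, 20000)
```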
x_transformers/x_transformers.py

```diff
@@ -1617,7 +1617,8 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
-        value_residual = None
+        value_residual = None,
+        additional_key_values: tuple[Tensor, Tensor] | None = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
```
```diff
@@ -1787,6 +1788,19 @@ class Attention(Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
 
+        # maybe append additional key / values
+
+        if exists(additional_key_values):
+
+            added_k, added_v = additional_key_values
+            added_kv_len = added_k.shape[-2]
+
+            k = cat((added_k, k), dim = -2)
+            v = cat((added_v, v), dim = -2)
+
+            if exists(input_mask):
+                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+
         # determine masking
 
         mask_value = max_neg_value(q)
```
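The hunk above prepends the supplied keys and values along the sequence dimension and left-pads any key-padding mask with `True`, so the injected positions are always attendable. A standalone sketch of that tensor bookkeeping with plain `torch` ops (the shapes and the `(batch, seq)` mask layout are assumptions for illustration; the library itself uses its internal `pad_at_dim` helper):

```python
import torch

b, h, n, d, added_len = 2, 8, 128, 16, 32

k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)
input_mask = torch.ones(b, n, dtype = torch.bool)  # True = attend

added_k = torch.randn(b, h, added_len, d)
added_v = torch.randn(b, h, added_len, d)

# prepend along the key / value sequence dimension (dim = -2)
k = torch.cat((added_k, k), dim = -2)
v = torch.cat((added_v, v), dim = -2)

# extend the mask on the left so the added positions are never masked out
mask_pad = torch.ones(b, added_len, dtype = torch.bool)
input_mask = torch.cat((mask_pad, input_mask), dim = -1)

assert k.shape == (b, h, n + added_len, d)
assert input_mask.shape == (b, n + added_len)
```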
```diff
@@ -2267,7 +2281,7 @@ class AttentionLayers(Module):
 
         # whether it has post norm
 
-        self.final_norm = norm_fn() if pre_norm else nn.Identity()
+        self.final_norm = norm_fn() if pre_norm and pre_norm_has_final_norm else nn.Identity()
 
         # whether unet or not
 
```
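The one-line change above wires the existing `pre_norm_has_final_norm` flag into the final-norm construction; previously a pre-norm stack always created the trailing norm regardless of the flag. If that reading is right, opting out now actually yields an `nn.Identity`, roughly:

```python
from torch import nn
from x_transformers import Decoder

# default pre-norm decoder keeps the trailing norm
decoder = Decoder(dim = 512, depth = 2, heads = 8)

# with this fix, disabling the flag replaces it with an identity
decoder_no_final = Decoder(dim = 512, depth = 2, heads = 8, pre_norm_has_final_norm = False)

assert not isinstance(decoder.final_norm, nn.Identity)
assert isinstance(decoder_no_final.final_norm, nn.Identity)
```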
```diff
@@ -2411,6 +2425,7 @@ class AttentionLayers(Module):
         context_pos = None,
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
+        self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
```
```diff
@@ -2520,6 +2535,10 @@ class AttentionLayers(Module):
 
         iter_attn_cache = iter(attn_cache)
 
+        # additional self attn key / values
+
+        iter_self_attn_kv = iter(default(self_attn_additional_kv, ()))
+
         # handle deep embeds if needed
 
         deep_embeds = []
```
```diff
@@ -2647,7 +2666,7 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
```
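Because the layer loop draws from `iter(default(self_attn_additional_kv, ()))` with `next(iter_self_attn_kv, None)`, entries are consumed one per self-attention layer in order, and layers past the end of the list receive `None` and run unchanged. A hedged sketch of that behaviour (the partial-list usage is an inference from the iterator handling, not something the new test exercises):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8, attn_dim_head = 16)
)

seq = torch.randint(0, 20000, (2, 256))

# extra keys / values for the first two layers only;
# the remaining two layers should fall back to None
partial_kv = [
    (torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)),
    (torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)),
]

logits = model(seq, self_attn_additional_kv = partial_kv)  # (2, 256, 20000)
```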