x-transformers 2.11.7__tar.gz → 2.11.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.7 → x_transformers-2.11.9}/PKG-INFO +1 -1
- {x_transformers-2.11.7 → x_transformers-2.11.9}/pyproject.toml +1 -1
- {x_transformers-2.11.7 → x_transformers-2.11.9}/tests/test_x_transformers.py +21 -2
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/free_transformer.py +5 -6
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/x_transformers.py +29 -3
- {x_transformers-2.11.7 → x_transformers-2.11.9}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/.gitignore +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/LICENSE +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/README.md +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/data/README.md +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/data/enwik8.gz +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/all-attention.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/deepnorm.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/fcm.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/ffglu.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/flash-attention.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/gate_values.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/gating.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/macaron-1.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/macaron-2.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/normformer.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/pia.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/resi_dual.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/residual_attn.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/rezero.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/rotary.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/sandwich.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/scalenorm.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/talking-heads.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/topk-attention.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/images/xval.png +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_belief_state.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_copy.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_enwik8.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_free.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_parity.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/train_with_muon.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/xval.py +0 -0
{x_transformers-2.11.7 → x_transformers-2.11.9}/tests/test_x_transformers.py

@@ -1410,7 +1410,9 @@ def test_attn_negative_weights(
     logits = model(x)
 
 @param('per_token_latents', (False, True))
+@param('dec_head_depth', (0, 4))
 def test_free(
+    dec_head_depth,
     per_token_latents
 ):
     from x_transformers.free_transformer import FreeTransformer
@@ -1420,9 +1422,9 @@ def test_free(
         max_seq_len = 1024,
         dim = 512,
         heads = 8,
-        dec_head_depth =
+        dec_head_depth = dec_head_depth,
         dec_tail_depth = 4,
-        enc_depth =
+        enc_depth = 2,
         kl_loss_weight = 1.,
         per_token_latents = per_token_latents,
         latent_bits = 8
@@ -1434,3 +1436,20 @@ def test_free(
     loss.backward()
 
     assert aux_loss.numel() == 1
+
+def test_kv_input_residual():
+    attn = Decoder(
+        dim = 256,
+        depth = 2,
+        heads = 4,
+        cross_attend = True
+    )
+
+    tokens = torch.randn(3, 32, 256)
+    context = torch.randn(3, 64, 256)
+
+    condition = torch.randn(2, 3, 64, 256)
+
+    out = attn(tokens, context = context, cross_attn_kv_residuals = condition)
+
+    assert tokens.shape == out.shape
{x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/free_transformer.py

@@ -197,7 +197,9 @@ class FreeTransformer(Module):
             pre_norm_has_final_norm = False,
             **kwargs,
             **dec_kwargs
-        )
+        ) if dec_head_depth > 0 else nn.Identity()
+
+        assert dec_tail_depth > 0
 
         self.decoder_tail = Decoder(
             dim = dim,
@@ -296,10 +298,7 @@ class FreeTransformer(Module):
 
         head_embed = self.decoder_head(tokens)
 
-
-        head_embed = head_embed + condition
-
-        tail_embed = self.decoder_tail(head_embed)
+        tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
 
         tail_embed = tail_embed[:, -1]
 
@@ -339,7 +338,7 @@ class FreeTransformer(Module):
 
         # decoder tail
 
-        tokens = self.decoder_tail(tokens
+        tokens = self.decoder_tail(tokens, self_attn_kv_residuals = condition)
 
         # cross entropy loss
 
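In use, the change above means the latent condition now enters the decoder tail as a key/value-input residual instead of being summed into the head embedding, and the decoder head can be dropped entirely by setting dec_head_depth = 0 (it becomes nn.Identity()). A minimal construction sketch, with hyperparameters mirroring the updated test_free; num_tokens is an assumed value not shown in the diff:

    import torch
    from x_transformers.free_transformer import FreeTransformer

    model = FreeTransformer(
        num_tokens = 256,        # assumption for illustration, not shown in the diff
        max_seq_len = 1024,
        dim = 512,
        heads = 8,
        dec_head_depth = 0,      # new in 2.11.9: 0 swaps the decoder head for nn.Identity()
        dec_tail_depth = 4,      # must remain > 0, per the new assert
        enc_depth = 2,
        kl_loss_weight = 1.,
        latent_bits = 8
    )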
{x_transformers-2.11.7 → x_transformers-2.11.9}/x_transformers/x_transformers.py

@@ -1686,6 +1686,7 @@ class Attention(Module):
         value_residual = None,
         additional_key_values: tuple[Tensor, Tensor] | None = None,
         additional_key_value_mask = None,
+        kv_input_residual = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1702,6 +1703,12 @@ class Attention(Module):
         kv_input = default(context, x)
         q_input, k_input, v_input = x, kv_input, kv_input
 
+        # done for free transformer
+
+        if exists(kv_input_residual):
+            k_input = k_input + kv_input_residual
+            v_input = v_input + kv_input_residual
+
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
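The new kv_input_residual argument is simple: before the key/value projections, the residual is added to the key and value inputs only, leaving the queries untouched. A plain-tensor sketch of that step (shapes are illustrative assumptions):

    import torch

    b, n, d = 2, 16, 64
    kv_input = torch.randn(b, n, d)           # context for cross attention, otherwise x
    kv_input_residual = torch.randn(b, n, d)  # e.g. the free transformer's condition

    # mirrors the inserted lines in Attention.forward
    k_input = kv_input + kv_input_residual
    v_input = kv_input + kv_input_residual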
@@ -2543,7 +2550,9 @@ class AttentionLayers(Module):
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
-        layers_execute_order: tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None,
+        self_attn_kv_residuals: Tensor | None = None,
+        cross_attn_kv_residuals: Tensor | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -2721,6 +2730,23 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for residuals to key value inputs for self and cross attention
+
+        self_attn_kv_residuals_iter = iter((None,))
+        cross_attn_kv_residuals_iter = iter((None,))
+
+        if exists(self_attn_kv_residuals):
+            if self_attn_kv_residuals.ndim == 3:
+                self_attn_kv_residuals = rearrange(self_attn_kv_residuals, '... -> 1 ...')
+
+            self_attn_kv_residuals_iter = iter(self_attn_kv_residuals)
+
+        if exists(cross_attn_kv_residuals):
+            if cross_attn_kv_residuals.ndim == 3:
+                cross_attn_kv_residuals = rearrange(cross_attn_kv_residuals, '... -> 1 ...')
+
+            cross_attn_kv_residuals_iter = iter(cross_attn_kv_residuals)
+
         # for value residuals
 
         first_self_attn_inter = None
@@ -2794,9 +2820,9 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x, deep_embed = next(deep_embeds_iter, None))
 
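At the AttentionLayers level, the new self_attn_kv_residuals / cross_attn_kv_residuals arguments accept a tensor with an optional leading layer dimension; each attention layer pulls its own residual from an iterator, and a plain (batch, seq, dim) tensor is given a leading dimension of 1, so it appears to reach only the first such layer. A usage sketch mirroring the new test_kv_input_residual, assuming Decoder imported from x_transformers:

    import torch
    from x_transformers import Decoder

    attn = Decoder(
        dim = 256,
        depth = 2,
        heads = 4,
        cross_attend = True
    )

    tokens  = torch.randn(3, 32, 256)    # (batch, seq, dim)
    context = torch.randn(3, 64, 256)    # cross-attention context

    # leading dim of 2 supplies one residual per cross-attention layer (depth = 2)
    condition = torch.randn(2, 3, 64, 256)

    out = attn(tokens, context = context, cross_attn_kv_residuals = condition)
    assert out.shape == tokens.shape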