x-transformers 2.11.5.tar.gz → 2.11.8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of x-transformers might be problematic.
- {x_transformers-2.11.5 → x_transformers-2.11.8}/PKG-INFO +1 -1
- {x_transformers-2.11.5 → x_transformers-2.11.8}/pyproject.toml +1 -1
- {x_transformers-2.11.5 → x_transformers-2.11.8}/tests/test_x_transformers.py +17 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/free_transformer.py +28 -21
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/x_transformers.py +29 -3
- {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/FUNDING.yml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/.gitignore +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/LICENSE +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/README.md +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/data/README.md +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/data/enwik8.gz +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/all-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/attention-on-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/deepnorm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/fcm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/ffglu.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/flash-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/gate_values.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/gating.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/macaron-1.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/macaron-2.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/memory-transformer.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/normformer.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/pia.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/resi_dual.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/residual_attn.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/rezero.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/rotary.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich-2.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/sandwich_norm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/scalenorm.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/talking-heads.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/topk-attention.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/images/xval.png +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_belief_state.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_copy.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_enwik8.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_free.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_gpt_vae.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_length_extrapolate.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_parity.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/train_with_muon.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/__init__.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/attend.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/continuous.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/dpo.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.11.5 → x_transformers-2.11.8}/x_transformers/xval.py +0 -0
tests/test_x_transformers.py

@@ -1434,3 +1434,20 @@ def test_free(
     loss.backward()
 
     assert aux_loss.numel() == 1
+
+def test_kv_input_residual():
+    attn = Decoder(
+        dim = 256,
+        depth = 2,
+        heads = 4,
+        cross_attend = True
+    )
+
+    tokens = torch.randn(3, 32, 256)
+    context = torch.randn(3, 64, 256)
+
+    condition = torch.randn(2, 3, 64, 256)
+
+    out = attn(tokens, context = context, cross_attn_kv_residuals = condition)
+
+    assert tokens.shape == out.shape
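For orientation: in the new test, `condition` has shape (cross attention layers, batch, context length, dim), i.e. one residual per cross attention layer. Per the AttentionLayers hunk further below, a 3-dim tensor is also accepted and is given a leading layer dimension of 1, so only the first cross attention layer receives it. A minimal sketch of that variant, assuming x-transformers 2.11.8 (the version that introduces `cross_attn_kv_residuals`):

import torch
from x_transformers import Decoder

attn = Decoder(
    dim = 256,
    depth = 2,
    heads = 4,
    cross_attend = True
)

tokens = torch.randn(3, 32, 256)
context = torch.randn(3, 64, 256)

# a single (batch, context length, dim) residual - broadcast to a leading
# layer dimension of 1, so it conditions the first cross attention layer
condition = torch.randn(3, 64, 256)

out = attn(tokens, context = context, cross_attn_kv_residuals = condition)
assert out.shape == tokens.shape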
x_transformers/free_transformer.py

@@ -66,7 +66,6 @@ class BinaryMapper(Module):
 
         self.bits = bits
         self.num_codes = 2 ** bits
-        self.kl_loss_threshold = kl_loss_threshold
 
         power_two = 2 ** arange(bits)
         codes = (arange(self.num_codes)[:, None].bitwise_and(power_two) != 0).byte().bool()
@@ -74,13 +73,20 @@ class BinaryMapper(Module):
         self.register_buffer('power_two', power_two, persistent = False)
         self.register_buffer('codes', codes, persistent = False)
 
+        # aux loss
+
+        self.kl_loss_threshold = kl_loss_threshold
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def forward(
         self,
         logits,
         temperature = 1.,
-        straight_through = None
+        straight_through = None,
+        calc_aux_loss = None
     ):
         straight_through = default(straight_through, self.training)
+        calc_aux_loss = default(calc_aux_loss, self.training)
 
         assert logits.shape[-1] == self.bits, f'logits must have a last dimension of {self.bits}'
 
@@ -95,26 +101,29 @@ class BinaryMapper(Module):
 
         one_hot = F.one_hot(indices, self.num_codes).float()
 
-        #
+        # maybe calculate aux loss
 
-
-        return one_hot
+        aux_kl_loss = self.zero
 
-
+        if calc_aux_loss:
+            # calculate negative entropy
 
-
-
+            kl_div = self.bits * NAT - binary_entropy(logits)
+            aux_kl_loss = F.relu(kl_div - self.kl_loss_threshold).mean()
 
-        #
+        # maybe straight through
 
-
-
-            einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
-        ).exp()
+        if straight_through:
+            # get the soft G for the gradients and do a straight through
 
-
+            soft_G = (
+                einsum(F.logsigmoid(logits), self.codes.float(), '... bits, codes bits -> ... codes') +
+                einsum(F.logsigmoid(-logits), (~self.codes).float(), '... bits, codes bits -> ... codes')
+            ).exp()
 
-
+            # straight through
+
+            one_hot = one_hot + soft_G - soft_G.detach()
 
         return one_hot, aux_kl_loss
 
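Net effect of the BinaryMapper changes: the KL auxiliary loss is computed only when `calc_aux_loss` is set (defaulting to `self.training`), the registered `zero` buffer is returned otherwise, and the forward now always returns a `(one_hot, aux_kl_loss)` pair. A rough usage sketch, assuming the constructor takes `bits` and `kl_loss_threshold` keywords (inferred from the attributes set in the hunk above, not shown in this diff):

import torch
from x_transformers.free_transformer import BinaryMapper

# constructor keywords are assumptions inferred from the hunk above
mapper = BinaryMapper(bits = 4, kl_loss_threshold = 0.1)

logits = torch.randn(2, 8, 4)  # last dim must equal `bits`

# eval-style call: skip the aux loss, get back the zero buffer
mapper.eval()
one_hot, aux_kl_loss = mapper(logits, calc_aux_loss = False)

# training-style call: straight-through estimator plus the KL auxiliary loss
mapper.train()
one_hot, aux_kl_loss = mapper(logits, calc_aux_loss = True)
assert aux_kl_loss.numel() == 1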
@@ -163,6 +172,7 @@ class FreeTransformer(Module):
             cross_attend = True,
             use_rmsnorm = True,
             rotary_pos_emb = True,
+            pre_norm_has_final_norm = True,
             **kwargs,
             **enc_kwargs
         )
@@ -242,7 +252,7 @@ class FreeTransformer(Module):
 
         bit_logits = self.to_latent_bit_logits(pooled)
 
-        one_hot_latents, kl_loss = self.binary_mapper(bit_logits,
+        one_hot_latents, kl_loss = self.binary_mapper(bit_logits, calc_aux_loss = return_kl_loss)
 
         if not return_kl_loss:
             return one_hot_latents
@@ -286,10 +296,7 @@ class FreeTransformer(Module):
 
         head_embed = self.decoder_head(tokens)
 
-
-        head_embed = head_embed + condition
-
-        tail_embed = self.decoder_tail(head_embed)
+        tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
 
         tail_embed = tail_embed[:, -1]
 
@@ -329,7 +336,7 @@ class FreeTransformer(Module):
 
         # decoder tail
 
-        tokens = self.decoder_tail(tokens
+        tokens = self.decoder_tail(tokens, self_attn_kv_residuals = condition)
 
         # cross entropy loss
 
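The conditioning pathway therefore changes: instead of adding `condition` directly onto the hidden states (`head_embed = head_embed + condition`), the latent is now injected as a residual on the key/value inputs of the decoder tail's self attention. A minimal sketch of the same call pattern on a bare Decoder, assuming 2.11.8 (FreeTransformer builds its own `decoder_tail` internally; this is only an illustration of the call):

import torch
from x_transformers import Decoder

decoder_tail = Decoder(dim = 256, depth = 2, heads = 4)

head_embed = torch.randn(3, 32, 256)  # hidden states from the decoder head
condition = torch.randn(3, 32, 256)   # latent condition, one vector per position

# the condition is added to the key/value inputs of the first self attention
# layer rather than to the residual stream itself
tail_embed = decoder_tail(head_embed, self_attn_kv_residuals = condition)
assert tail_embed.shape == head_embed.shape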
x_transformers/x_transformers.py

@@ -1686,6 +1686,7 @@ class Attention(Module):
         value_residual = None,
         additional_key_values: tuple[Tensor, Tensor] | None = None,
         additional_key_value_mask = None,
+        kv_input_residual = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1702,6 +1703,12 @@
         kv_input = default(context, x)
         q_input, k_input, v_input = x, kv_input, kv_input
 
+        # done for free transformer
+
+        if exists(kv_input_residual):
+            k_input = k_input + kv_input_residual
+            v_input = v_input + kv_input_residual
+
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
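In other words, the new `kv_input_residual` is added to the key/value inputs before they are projected, leaving the queries untouched. A self-contained sketch of that mechanic in plain PyTorch (not the library's Attention class), just to make the data flow explicit:

import torch
import torch.nn.functional as F
from torch import nn

dim = 64

to_q = nn.Linear(dim, dim, bias = False)
to_k = nn.Linear(dim, dim, bias = False)
to_v = nn.Linear(dim, dim, bias = False)

x = torch.randn(2, 16, dim)                  # query input
kv_input = torch.randn(2, 32, dim)           # context (or x itself for self attention)
kv_input_residual = torch.randn(2, 32, dim)  # the new residual

# mirrors the two added lines in the hunk above
k_input = kv_input + kv_input_residual
v_input = kv_input + kv_input_residual

q, k, v = to_q(x), to_k(k_input), to_v(v_input)

out = F.scaled_dot_product_attention(q, k, v)  # (2, 16, 64)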
@@ -2543,7 +2550,9 @@
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
-        layers_execute_order: tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None,
+        self_attn_kv_residuals: Tensor | None = None,
+        cross_attn_kv_residuals: Tensor | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -2721,6 +2730,23 @@
 
         skip_hiddens = []
 
+        # for residuals to key value inputs for self and cross attention
+
+        self_attn_kv_residuals_iter = iter((None,))
+        cross_attn_kv_residuals_iter = iter((None,))
+
+        if exists(self_attn_kv_residuals):
+            if self_attn_kv_residuals.ndim == 3:
+                self_attn_kv_residuals = rearrange(self_attn_kv_residuals, '... -> 1 ...')
+
+            self_attn_kv_residuals_iter = iter(self_attn_kv_residuals)
+
+        if exists(cross_attn_kv_residuals):
+            if cross_attn_kv_residuals.ndim == 3:
+                cross_attn_kv_residuals = rearrange(cross_attn_kv_residuals, '... -> 1 ...')
+
+            cross_attn_kv_residuals_iter = iter(cross_attn_kv_residuals)
+
         # for value residuals
 
         first_self_attn_inter = None
@@ -2794,9 +2820,9 @@
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x, deep_embed = next(deep_embeds_iter, None))
 
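The routing above iterates over a leading layer dimension, handing `next(...)` of the residual stack to each attention block and `None` once the stack is exhausted. So a 4-dim tensor supplies one residual per self attention layer, as sketched below under the same 2.11.8 assumption:

import torch
from x_transformers import Decoder

decoder = Decoder(dim = 128, depth = 3, heads = 4)

x = torch.randn(2, 16, 128)

# one residual per self attention layer: (layers, batch, seq, dim)
per_layer_residuals = torch.randn(3, 2, 16, 128)

out = decoder(x, self_attn_kv_residuals = per_layer_residuals)
assert out.shape == x.shape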