x-transformers 2.11.7__py3-none-any.whl → 2.11.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/free_transformer.py +2 -5
- x_transformers/x_transformers.py +29 -3
- {x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/METADATA +1 -1
- {x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/RECORD +6 -6
- {x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/WHEEL +0 -0
- {x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/licenses/LICENSE +0 -0
x_transformers/free_transformer.py
CHANGED

@@ -296,10 +296,7 @@ class FreeTransformer(Module):
 
         head_embed = self.decoder_head(tokens)
 
-
-        head_embed = head_embed + condition
-
-        tail_embed = self.decoder_tail(head_embed)
+        tail_embed = self.decoder_tail(head_embed, self_attn_kv_residuals = condition)
 
         tail_embed = tail_embed[:, -1]
 
@@ -339,7 +336,7 @@ class FreeTransformer(Module):
 
         # decoder tail
 
-        tokens = self.decoder_tail(tokens)
+        tokens = self.decoder_tail(tokens, self_attn_kv_residuals = condition)
 
         # cross entropy loss
 
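In short, the conditioning vector no longer enters the decoder tail through the residual stream; it is handed to the tail's attention layers as a key/value input residual. A minimal sketch of the two call shapes, with a plain Decoder standing in for FreeTransformer's actual decoder_tail and made-up tensor sizes:

# illustrative only: toy dims, and a plain Decoder in place of the real decoder_tail
import torch
from x_transformers import Decoder

decoder_tail = Decoder(dim = 64, depth = 2, heads = 4)

head_embed = torch.randn(1, 16, 64)
condition  = torch.randn(1, 16, 64)

# 2.11.7: condition was added to the hidden states feeding the tail
old_tail_embed = decoder_tail(head_embed + condition)

# 2.11.8: condition is added only to the key/value inputs of the tail's
# self attention, through the new `self_attn_kv_residuals` kwarg introduced below
new_tail_embed = decoder_tail(head_embed, self_attn_kv_residuals = condition)

Per the iterator logic added to AttentionLayers.forward further down, a 3-dimensional residual like this is routed only to the first self-attention block of the stack.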
x_transformers/x_transformers.py
CHANGED
@@ -1686,6 +1686,7 @@ class Attention(Module):
         value_residual = None,
         additional_key_values: tuple[Tensor, Tensor] | None = None,
         additional_key_value_mask = None,
+        kv_input_residual = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
 
@@ -1702,6 +1703,12 @@ class Attention(Module):
         kv_input = default(context, x)
         q_input, k_input, v_input = x, kv_input, kv_input
 
+        # done for free transformer
+
+        if exists(kv_input_residual):
+            k_input = k_input + kv_input_residual
+            v_input = v_input + kv_input_residual
+
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')
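The hook only changes what the keys and values are computed from; the query input, and therefore the residual stream, is untouched. A standalone restatement of the added lines (the helper name `apply_kv_input_residual` is hypothetical, not part of the library):

from torch import Tensor

def apply_kv_input_residual(
    q_input: Tensor,                         # (b, n, d), always taken from x
    k_input: Tensor,                         # (b, n, d): x, or context when cross attending
    v_input: Tensor,                         # (b, n, d), same source as k_input
    kv_input_residual: Tensor | None = None  # must broadcast against k_input / v_input
):
    # mirrors the new lines in Attention.forward: applied before memory-key
    # packing and before the k/v projections; queries are left as-is
    if kv_input_residual is not None:
        k_input = k_input + kv_input_residual
        v_input = v_input + kv_input_residual
    return q_input, k_input, v_input

In effect the condition biases what each position attends to and what it reads back, rather than shifting the hidden states that form the queries.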
@@ -2543,7 +2550,9 @@ class AttentionLayers(Module):
         route_additional_kv_to_top = True,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
-        layers_execute_order: tuple[int, ...] | None = None
+        layers_execute_order: tuple[int, ...] | None = None,
+        self_attn_kv_residuals: Tensor | None = None,
+        cross_attn_kv_residuals: Tensor | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -2721,6 +2730,23 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for residuals to key value inputs for self and cross attention
+
+        self_attn_kv_residuals_iter = iter((None,))
+        cross_attn_kv_residuals_iter = iter((None,))
+
+        if exists(self_attn_kv_residuals):
+            if self_attn_kv_residuals.ndim == 3:
+                self_attn_kv_residuals = rearrange(self_attn_kv_residuals, '... -> 1 ...')
+
+            self_attn_kv_residuals_iter = iter(self_attn_kv_residuals)
+
+        if exists(cross_attn_kv_residuals):
+            if cross_attn_kv_residuals.ndim == 3:
+                cross_attn_kv_residuals = rearrange(cross_attn_kv_residuals, '... -> 1 ...')
+
+            cross_attn_kv_residuals_iter = iter(cross_attn_kv_residuals)
+
         # for value residuals
 
         first_self_attn_inter = None
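The two iterators above decide how residuals are spread across layers: a 3-dimensional tensor is lifted to a leading layer axis of length one, so only the first matching attention block receives it, while a 4-dimensional tensor supplies one residual per block, with any further blocks falling back to None via next(..., None). A sketch with made-up sizes:

# illustrative only; dim, depth and shapes are arbitrary
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 64, depth = 3, heads = 4)
x = torch.randn(2, 16, 64)

# 3-dim residual -> lifted to a layer axis of length 1, so only the first
# self-attention block is conditioned
out = attn_layers(x, self_attn_kv_residuals = torch.randn(2, 16, 64))

# 4-dim residual -> leading dim indexes self-attention blocks (here depth = 3,
# one 'a' block per depth in a default Decoder), one residual per block
out = attn_layers(x, self_attn_kv_residuals = torch.randn(3, 2, 16, 64))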
@@ -2794,9 +2820,9 @@ class AttentionLayers(Module):
             # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, kv_input_residual = next(self_attn_kv_residuals_iter, None), value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), kv_input_residual = next(cross_attn_kv_residuals_iter, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x, deep_embed = next(deep_embeds_iter, None))
 
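Cross-attention blocks get the same treatment through `cross_attn_kv_residuals`, except the residual lands on the context-derived keys and values, so it has to match (or broadcast against) the context shape rather than the decoder sequence. A sketch under those assumptions, again with arbitrary sizes:

# illustrative only; assumes a cross-attending decoder stack
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 64, depth = 2, heads = 4, cross_attend = True)

x        = torch.randn(2, 16, 64)   # decoder hidden states
context  = torch.randn(2, 32, 64)   # encoder output consumed by the 'c' blocks
residual = torch.randn(2, 32, 64)   # added to the context before the k/v projections

out = attn_layers(x, context = context, cross_attn_kv_residuals = residual)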
{x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/RECORD
CHANGED

@@ -5,16 +5,16 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
-x_transformers/free_transformer.py,sha256=
+x_transformers/free_transformer.py,sha256=kfl_MIZxv4TARRQbq3NroGwZSBVHdYoJNu1hfWMloco,9555
 x_transformers/gpt_vae.py,sha256=4QdznXZcU7pmMXUeEocAOKpcTkREYS-zDHktN5ADtNk,5981
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=bYnVtkcfr082ALprIGgYIUx53lLADGYpi9t6QEJp1Kc,126907
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
-x_transformers-2.11.
+x_transformers-2.11.8.dist-info/METADATA,sha256=NTTtQVh5bRCnk7RDpma7JHairCWIvaO2euEE2djXUFA,96011
+x_transformers-2.11.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.11.8.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.11.8.dist-info/RECORD,,
{x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/WHEEL
File without changes

{x_transformers-2.11.7.dist-info → x_transformers-2.11.8.dist-info}/licenses/LICENSE
File without changes