x-transformers 1.27.22__py3-none-any.whl → 1.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -3
- {x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/METADATA +1 -1
- {x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/RECORD +6 -6
- {x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -889,7 +889,15 @@ class Attention(nn.Module):
             else:
                 input_mask = torch.cat((mem_mask, input_mask), dim = -1)
 
-        if self.num_mem_kv > 0:
+        # i, j determined for relative positional bias, excluding memory key / values
+
+        i, j = map(lambda t: t.shape[-2], (q, k))
+
+        # maybe append memory key / values
+
+        has_mem_kv = self.num_mem_kv > 0
+
+        if has_mem_kv:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
 
         if self.qk_norm:
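As a reading aid for this hunk, here is a minimal, self-contained sketch of the reordering it introduces: i and j are now read off q and k before the learned memory key / values are prepended, so the relative positional bias built later from rel_pos(i, j) covers only the real sequence positions. This is not the library's code; the shapes, the num_mem_kv value and the tensor contents below are made up for illustration.

```python
import torch
from einops import repeat

# hypothetical shapes for illustration only
b, h, n, d = 2, 8, 16, 64     # batch, heads, sequence length, head dim
num_mem_kv = 4                # pretend Attention was built with num_mem_kv = 4

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# lengths for the relative positional bias, taken BEFORE memory is appended
i, j = map(lambda t: t.shape[-2], (q, k))   # i == j == n

# learned memory key / values are shared across the batch, then prepended
mem_k = torch.randn(h, num_mem_kv, d)
mem_v = torch.randn(h, num_mem_kv, d)
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (mem_k, mem_v))

k = torch.cat((mem_k, k), dim = -2)
v = torch.cat((mem_v, v), dim = -2)

assert (i, j) == (n, n)                 # bias will be sized for real tokens only
assert k.shape[-2] == num_mem_kv + n    # keys / values grew by the memory slots
```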
@@ -902,8 +910,6 @@ class Attention(nn.Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
 
-        i, j = map(lambda t: t.shape[-2], (q, k))
-
         # determine masking
 
         mask_value = max_neg_value(q)
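For context on the pad_at_dim call kept as context in this hunk, the sketch below is a hedged reimplementation of that helper (assuming it wraps F.pad, consistent with the (left, right) padding tuple seen in the diff), showing how the key mask is left-padded with True so the prepended memory key / values are never masked out. The shapes and num_mem_kv value are illustrative.

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # interpret `pad` as (left, right) padding on `dim`, translated into
    # F.pad's flat argument list, which runs from the last dimension backwards
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

num_mem_kv = 4
input_mask = torch.ones(2, 16, dtype = torch.bool)   # (batch, keys) mask

padded = pad_at_dim(input_mask, (num_mem_kv, 0), dim = -1, value = True)
print(padded.shape)   # torch.Size([2, 20]): memory positions stay attendable
```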
@@ -938,6 +944,11 @@ class Attention(nn.Module):
         if exists(rel_pos):
             attn_bias = rel_pos(i, j)
 
+        # append with no bias for memory key / values
+
+        if exists(attn_bias) and has_mem_kv:
+            attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+
         # attention is all we need
 
         out, intermediates = self.attend(
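The sketch below illustrates the effect of the new attn_bias padding under assumed shapes (it uses F.pad directly rather than the library's pad_at_dim helper): the relative positional bias is computed for the real i × j token grid, then left-padded with zeros along the key dimension so it lines up with the num_mem_kv memory slots prepended to the keys and values. A zero bias leaves the memory attention logits unchanged.

```python
import torch
import torch.nn.functional as F

h, i, j, num_mem_kv = 8, 16, 16, 4          # hypothetical sizes

attn_bias = torch.randn(h, i, j)            # bias produced for real tokens only

# left-pad the key dimension with zeros so the bias covers the prepended
# memory slots; adding zero to a logit changes nothing for those positions
attn_bias = F.pad(attn_bias, (num_mem_kv, 0), value = 0.)

print(attn_bias.shape)   # torch.Size([8, 16, 20]) == (h, i, num_mem_kv + j)
```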
{x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=F1ukKlTGhNaYmMLnm7nJy0XaUyAYwtrGGq8Gw94RPyQ,63919
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.28.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.1.dist-info/METADATA,sha256=BjW_kBCR9LCWa1CbdAkd7vMBEVYkXTgLBbNw74Tz2R4,661
+x_transformers-1.28.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.1.dist-info/RECORD,,
{x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/LICENSE
File without changes
{x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/WHEEL
File without changes
{x_transformers-1.27.22.dist-info → x_transformers-1.28.1.dist-info}/top_level.txt
File without changes