x-transformers 1.27.22__py3-none-any.whl → 1.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -889,7 +889,15 @@ class Attention(nn.Module):
             else:
                 input_mask = torch.cat((mem_mask, input_mask), dim = -1)
 
-        if self.num_mem_kv > 0:
+        # i, j determined for relative positional bias, excluding memory key / values
+
+        i, j = map(lambda t: t.shape[-2], (q, k))
+
+        # maybe append memory key / values
+
+        has_mem_kv = self.num_mem_kv > 0
+
+        if has_mem_kv:
             mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
 
             if self.qk_norm:
@@ -902,8 +910,6 @@ class Attention(nn.Module):
             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
 
-        i, j = map(lambda t: t.shape[-2], (q, k))
-
         # determine masking
 
         mask_value = max_neg_value(q)
@@ -938,6 +944,11 @@ class Attention(nn.Module):
         if exists(rel_pos):
             attn_bias = rel_pos(i, j)
 
+            # append with no bias for memory key / values
+
+            if has_mem_kv:
+                attn_bias = pad_at_dim(attn_bias, (self.num_mem_kv, 0), value = 0.)
+
         # attention is all we need
 
         out, intermediates = self.attend(
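In short, 1.28.0 computes i and j (the query and key lengths used for the relative positional bias) before any memory key / values are concatenated onto the keys, then zero-pads the bias along the key dimension so its shape matches the lengthened keys. The following is a minimal sketch of that padding step, not the library's code path; the inline pad_at_dim here is written out for the example and assumed to behave like the helper used in x_transformers/x_transformers.py.

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # pad dimension `dim` of tensor `t` by (left, right) amounts with a constant value
    # (assumed equivalent to the library's pad_at_dim helper)
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

b, h, d, num_mem_kv = 2, 8, 64, 4

q = torch.randn(b, h, 16, d)
k = torch.randn(b, h, 16, d)

# i, j taken from the real query / key lengths, before memory key / values are appended
i, j = map(lambda t: t.shape[-2], (q, k))

# the relative positional bias covers only the real positions
attn_bias = torch.randn(h, i, j)

# memory key / values lengthen the keys ...
mem_k = torch.randn(b, h, num_mem_kv, d)
k = torch.cat((mem_k, k), dim = -2)

# ... so the bias is padded with zeros over those memory slots
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.)

assert attn_bias.shape[-1] == k.shape[-2]   # j + num_mem_kv

Padding with zeros rather than recomputing the bias keeps the memory slots position-agnostic: they are attended to purely through their learned keys, with no relative-position preference.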
x_transformers-1.28.0.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.22
+Version: 1.28.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.28.0.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=kQhRUMGDsinzkdYcOfE1GriJ057j7D4xSjbH79FFRSE,63574
+x_transformers/x_transformers.py,sha256=GvqVKQZRtIldnSWX4V6qE2sWOGruRvBhk4MVit7ZD_M,63897
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.27.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.22.dist-info/METADATA,sha256=RTbXIIpFRnve8FVp8vLQ4LE-9x59IV6ADnu34gGAZXA,662
-x_transformers-1.27.22.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.27.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.22.dist-info/RECORD,,
+x_transformers-1.28.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.28.0.dist-info/METADATA,sha256=o1AbarRMIJY_R0gNaEm5SNUWm3YHEesLL2EEy_Uk6gA,661
+x_transformers-1.28.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.28.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.28.0.dist-info/RECORD,,