x-transformers 1.37.5__py3-none-any.whl → 1.37.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -64,15 +64,15 @@ print_once = once(print)
 
 # alternative distance functions
 
-def qk_l2_distance(q, k):
+def qk_l2_dist_squared(q, k):
     if k.ndim == 3:
         k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
 
     q, packed_shape = pack_one(q, '* i d')
     k, _ = pack_one(k, '* j d')
 
-    distance = torch.cdist(q, k)
-    return unpack_one(distance, packed_shape, '* i j')
+    l2_dist_squared = torch.cdist(q, k) ** 2
+    return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
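The renamed helper now returns the squared L2 distance (torch.cdist raised to the second power) rather than the raw distance. A minimal standalone check of that behavior, with made-up tensor shapes, not code from the package:

import torch

# hypothetical (batch, heads, seq, dim) tensors, purely for the check
q = torch.randn(2, 8, 4, 16)
k = torch.randn(2, 8, 6, 16)

# what the renamed helper computes (before its pack/unpack reshaping)
dist_sq = torch.cdist(q, k) ** 2

# explicit pairwise squared distances ||q_i - k_j||^2
explicit = (q.unsqueeze(-2) - k.unsqueeze(-3)).pow(2).sum(dim = -1)

assert torch.allclose(dist_sq, explicit, atol = 1e-3)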
@@ -145,7 +145,6 @@ class Attend(Module):
 
         # l2 distance attention
 
-        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
         self.l2_distance = l2_distance
 
         # add a key / value token composed of zeros
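The guard could be dropped because the flash path now handles this case: the hunk below rewrites L2 distance attention as an ordinary dot product over augmented queries and keys, which the fused kernel consumes directly (see the verification sketch after that hunk).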
@@ -208,6 +207,17 @@ class Attend(Module):
         if v.ndim == 3:
             v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
 
+        # handle maybe l2 distance
+
+        if self.l2_distance:
+            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
+            k = F.pad(k, (0, 1), value = 1.)
+            k = torch.cat((k, -k_norm_sq), dim = -1)
+
+            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
+            q = torch.cat((2 * q, -q_norm_sq), dim = -1)
+            q = F.pad(q, (0, 1), value = 1.)
+
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
 
         if exists(self.scale):
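The added block encodes the identity -||q - k||^2 = 2 q·k - ||q||^2 - ||k||^2: the keys become k' = [k, 1, -||k||^2] and the queries q' = [2q, -||q||^2, 1], so a plain dot product q'·k' yields the negative squared distance and the unmodified flash kernel can score by L2 distance. A standalone sketch verifying the identity (tensor shapes are assumptions):

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 3, 8)   # (batch, heads, query len, dim)
k = torch.randn(1, 4, 5, 8)   # (batch, heads, key len, dim)

# augment keys: k' = [k, 1, -||k||^2]
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = F.pad(k, (0, 1), value = 1.)
k_aug = torch.cat((k_aug, -k_norm_sq), dim = -1)

# augment queries: q' = [2q, -||q||^2, 1]
q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = torch.cat((2 * q, -q_norm_sq), dim = -1)
q_aug = F.pad(q_aug, (0, 1), value = 1.)

# ordinary dot-product scores on the augmented tensors ...
sim = q_aug @ k_aug.transpose(-1, -2)

# ... match the negative squared L2 distances computed directly
neg_dist_sq = -(torch.cdist(q, k) ** 2)

assert torch.allclose(sim, neg_dist_sq, atol = 1e-3)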
@@ -353,7 +363,7 @@ class Attend(Module):
         if not self.l2_distance:
             sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
         else:
-            sim = -qk_l2_distance(q, k)
+            sim = -qk_l2_dist_squared(q, k)
 
         sim = sim * scale
 
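With the non-flash path negating qk_l2_dist_squared and the flash path using the augmented dot product, flash and l2_distance can now be combined. A hedged usage sketch; the constructor taking only these two keywords and the (output, intermediates) return convention are assumptions based on the class as it appears in this diff:

import torch
from x_transformers.attend import Attend

# previously rejected by the removed assert
attend = Attend(flash = True, l2_distance = True)

q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

out, _ = attend(q, k, v)   # attention weighted by negative squared L2 distance
print(out.shape)           # expected: torch.Size([1, 8, 1024, 64])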
x_transformers-1.37.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.5
+Version: 1.37.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.37.7.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=4RnX1yhWZIf8holucqnYXTIP7U1m40UpP58RZNT_2sM,13128
+x_transformers/attend.py,sha256=BOnMjgV5O5DyAM_bUz1rI6n1j_eLXu8GIEljT-MMnWU,13434
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.5.dist-info/METADATA,sha256=zHUhvP1bQjFbMtxnVO9iDESgXpGOQxuBCsm4b6K1w44,661
-x_transformers-1.37.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.5.dist-info/RECORD,,
+x_transformers-1.37.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.7.dist-info/METADATA,sha256=svdldk1hpiBN4xxKkOaQuMEXAxvaj6fT7Ri9NXwZJCU,661
+x_transformers-1.37.7.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.7.dist-info/RECORD,,