x-transformers 1.37.5__py3-none-any.whl → 1.37.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -64,15 +64,15 @@ print_once = once(print)
 
 # alternative distance functions
 
-def qk_l2_distance(q, k):
+def qk_l2_dist_squared(q, k):
     if k.ndim == 3:
         k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
 
     q, packed_shape = pack_one(q, '* i d')
     k, _ = pack_one(k, '* j d')
 
-    distance = torch.cdist(q, k)
-    return unpack_one(distance, packed_shape, '* i j')
+    l2_dist_squared = torch.cdist(q, k) ** 2
+    return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
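The renamed helper now returns squared pairwise L2 distances between queries and keys instead of plain L2 distances. For illustration only, here is a minimal plain-PyTorch sketch of that behaviour; the actual function uses the einops repeat/pack_one/unpack_one helpers shown in the hunk, and the standalone name and shapes below are assumptions.

    import torch

    def qk_l2_dist_squared_sketch(q, k):
        # q: (batch, heads, q_len, dim); k is (batch, heads, k_len, dim),
        # or (batch, k_len, dim) when a single key head is shared across query heads
        if k.ndim == 3:
            k = k.unsqueeze(1).expand(-1, q.shape[1], -1, -1)

        b, h, i, d = q.shape
        j = k.shape[-2]

        # flatten batch and head dims, take pairwise L2 distances, then square them
        # (the squaring is the behavioural change in 1.37.6)
        dist = torch.cdist(q.reshape(b * h, i, d), k.reshape(b * h, j, d))
        return (dist ** 2).reshape(b, h, i, j)

    # example usage: q = torch.randn(2, 8, 16, 64), k = torch.randn(2, 32, 64)
    # -> output of shape (2, 8, 16, 32)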
@@ -353,7 +353,7 @@ class Attend(Module):
         if not self.l2_distance:
             sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
         else:
-            sim = -qk_l2_distance(q, k)
+            sim = -qk_l2_dist_squared(q, k)
 
         sim = sim * scale
 
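In the attention forward pass, the L2-distance branch negates this squared distance to form the pre-softmax logits. As background (not taken from the package), -||q - k||^2 expands to 2*(q . k) - ||q||^2 - ||k||^2; since ||q||^2 is constant over the keys for a given query, the softmax effectively sees a scaled dot product plus a per-key -||k||^2 bias. A small numerical check of that identity, with arbitrary illustrative shapes:

    import torch

    # illustration only, not package code: verify that the negative squared L2
    # distance equals a shifted dot product, -||q - k||^2 = 2*(q . k) - ||q||^2 - ||k||^2
    q = torch.randn(2, 8, 16, 64, dtype = torch.float64)
    k = torch.randn(2, 8, 32, 64, dtype = torch.float64)

    dot = torch.einsum('bhid,bhjd->bhij', q, k)
    q_sq = q.pow(2).sum(dim = -1, keepdim = True)   # (b, h, i, 1)
    k_sq = k.pow(2).sum(dim = -1).unsqueeze(-2)     # (b, h, 1, j)
    expanded = 2 * dot - q_sq - k_sq

    b, h, i, d = q.shape
    dist_sq = torch.cdist(q.reshape(b * h, i, d), k.reshape(b * h, -1, d)) ** 2
    assert torch.allclose(expanded, -dist_sq.reshape(b, h, i, -1), atol = 1e-6)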
x_transformers-1.37.5.dist-info/METADATA → x_transformers-1.37.6.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.37.5
+Version: 1.37.6
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.37.5.dist-info/RECORD → x_transformers-1.37.6.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=4RnX1yhWZIf8holucqnYXTIP7U1m40UpP58RZNT_2sM,13128
+x_transformers/attend.py,sha256=w3Gcy3gbeKGhIyKUiJKo8F90bPN-JcnQq3rjlVhG5sE,13155
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.37.5.dist-info/METADATA,sha256=zHUhvP1bQjFbMtxnVO9iDESgXpGOQxuBCsm4b6K1w44,661
-x_transformers-1.37.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.37.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.37.5.dist-info/RECORD,,
+x_transformers-1.37.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.6.dist-info/METADATA,sha256=IBUfeibj2CrLmh2-UPCokvC-ylmoBhc5OC-ZFRirvek,661
+x_transformers-1.37.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.6.dist-info/RECORD,,