x-transformers 1.37.5__tar.gz → 1.37.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.37.5/x_transformers.egg-info → x_transformers-1.37.6}/PKG-INFO +1 -1
  2. {x_transformers-1.37.5 → x_transformers-1.37.6}/setup.py +1 -1
  3. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/attend.py +4 -4
  4. {x_transformers-1.37.5 → x_transformers-1.37.6/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x_transformers-1.37.5 → x_transformers-1.37.6}/LICENSE +0 -0
  6. {x_transformers-1.37.5 → x_transformers-1.37.6}/README.md +0 -0
  7. {x_transformers-1.37.5 → x_transformers-1.37.6}/setup.cfg +0 -0
  8. {x_transformers-1.37.5 → x_transformers-1.37.6}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/autoregressive_wrapper.py +0 -0
  11. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/continuous.py +0 -0
  12. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/dpo.py +0 -0
  13. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/multi_input.py +0 -0
  14. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/nonautoregressive_wrapper.py +0 -0
  15. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/x_transformers.py +0 -0
  16. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.37.5 → x_transformers-1.37.6}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.37.5
3
+ Version: 1.37.6
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
  setup(
4
4
  name = 'x-transformers',
5
5
  packages = find_packages(exclude=['examples']),
6
- version = '1.37.5',
6
+ version = '1.37.6',
7
7
  license='MIT',
8
8
  description = 'X-Transformers - Pytorch',
9
9
  author = 'Phil Wang',
@@ -64,15 +64,15 @@ print_once = once(print)
64
64
 
65
65
  # alternative distance functions
66
66
 
67
- def qk_l2_distance(q, k):
67
+ def qk_l2_dist_squared(q, k):
68
68
  if k.ndim == 3:
69
69
  k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
70
70
 
71
71
  q, packed_shape = pack_one(q, '* i d')
72
72
  k, _ = pack_one(k, '* j d')
73
73
 
74
- distance = torch.cdist(q, k)
75
- return unpack_one(distance, packed_shape, '* i j')
74
+ l2_dist_squared = torch.cdist(q, k) ** 2
75
+ return unpack_one(l2_dist_squared, packed_shape, '* i j')
76
76
 
77
77
  # functions for creating causal mask
78
78
  # need a special one for onnx cpu (no support for .triu)
@@ -353,7 +353,7 @@ class Attend(Module):
353
353
  if not self.l2_distance:
354
354
  sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
355
355
  else:
356
- sim = -qk_l2_distance(q, k)
356
+ sim = -qk_l2_dist_squared(q, k)
357
357
 
358
358
  sim = sim * scale
359
359
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.37.5
3
+ Version: 1.37.6
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
File without changes