x-transformers 1.39.3__tar.gz → 1.39.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.39.3/x_transformers.egg-info → x_transformers-1.39.4}/PKG-INFO +1 -1
  2. {x_transformers-1.39.3 → x_transformers-1.39.4}/setup.py +1 -1
  3. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/attend.py +3 -3
  4. {x_transformers-1.39.3 → x_transformers-1.39.4/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x_transformers-1.39.3 → x_transformers-1.39.4}/LICENSE +0 -0
  6. {x_transformers-1.39.3 → x_transformers-1.39.4}/README.md +0 -0
  7. {x_transformers-1.39.3 → x_transformers-1.39.4}/setup.cfg +0 -0
  8. {x_transformers-1.39.3 → x_transformers-1.39.4}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/autoregressive_wrapper.py +0 -0
  11. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/continuous.py +0 -0
  12. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/dpo.py +0 -0
  13. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/multi_input.py +0 -0
  14. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  15. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/x_transformers.py +0 -0
  16. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.39.3 → x_transformers-1.39.4}/x_transformers.egg-info/top_level.txt +0 -0
--- x_transformers-1.39.3/x_transformers.egg-info/PKG-INFO
+++ x_transformers-1.39.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.3
+Version: 1.39.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.39.3/setup.py
+++ x_transformers-1.39.4/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.39.3',
+  version = '1.39.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
--- x_transformers-1.39.3/x_transformers/attend.py
+++ x_transformers-1.39.4/x_transformers/attend.py
@@ -107,9 +107,9 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
-# gumbel softmax
+# one-hot straight through softmax
 
-def gumbel_softmax(t, temperature = 1.):
+def one_hot_straight_through(t, temperature = 1.):
     one_hot_indices = t.argmax(dim = -1, keepdim = True)
     one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
 
@@ -180,7 +180,7 @@ class Attend(Module):
         elif sigmoid:
             self.attn_fn = F.sigmoid
         elif hard:
-            self.attn_fn = gumbel_softmax
+            self.attn_fn = one_hot_straight_through
         else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
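The attend.py hunks contain the only functional change in this release: the hard-attention helper formerly named gumbel_softmax is renamed to one_hot_straight_through, matching what the visible lines actually do (a plain argmax one-hot, with no Gumbel noise). The hunk truncates the function body, so the sketch below is only an illustration of how a straight-through one-hot softmax is commonly completed; the soft-gradient tail and the imports are assumptions for context, not the package's verbatim code.

```python
import torch
import torch.nn.functional as F

def one_hot_straight_through(t, temperature = 1.):
    # hard one-hot over the last dimension (the two lines visible in the hunk above)
    one_hot_indices = t.argmax(dim = -1, keepdim = True)
    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)

    # assumed continuation: a temperature-scaled softmax supplies the gradient
    # while the forward value stays exactly one-hot (straight-through estimator)
    soft = F.softmax(t / temperature, dim = -1)
    return one_hot + soft - soft.detach()
```

With hard = True, Attend assigns this helper to self.attn_fn (second hunk above), so attention weights are exactly one-hot in the forward pass while gradients still flow through the softmax path.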
--- x_transformers-1.39.3/PKG-INFO
+++ x_transformers-1.39.4/x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.3
+Version: 1.39.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang