x-transformers 1.26.3__tar.gz → 1.26.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18):
  1. {x-transformers-1.26.3/x_transformers.egg-info → x-transformers-1.26.4}/PKG-INFO +1 -1
  2. {x-transformers-1.26.3 → x-transformers-1.26.4}/setup.py +1 -1
  3. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/autoregressive_wrapper.py +8 -5
  4. {x-transformers-1.26.3 → x-transformers-1.26.4/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x-transformers-1.26.3 → x-transformers-1.26.4}/LICENSE +0 -0
  6. {x-transformers-1.26.3 → x-transformers-1.26.4}/README.md +0 -0
  7. {x-transformers-1.26.3 → x-transformers-1.26.4}/setup.cfg +0 -0
  8. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/__init__.py +0 -0
  9. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/attend.py +0 -0
  10. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/continuous.py +0 -0
  11. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
  12. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/x_transformers.py +0 -0
  13. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  14. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers/xval.py +0 -0
  15. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers.egg-info/SOURCES.txt +0 -0
  16. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers.egg-info/dependency_links.txt +0 -0
  17. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers.egg-info/requires.txt +0 -0
  18. {x-transformers-1.26.3 → x-transformers-1.26.4}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.26.3
3
+ Version: 1.26.4
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
  setup(
4
4
  name = 'x-transformers',
5
5
  packages = find_packages(exclude=['examples']),
6
- version = '1.26.3',
6
+ version = '1.26.4',
7
7
  license='MIT',
8
8
  description = 'X-Transformers - Pytorch',
9
9
  author = 'Phil Wang',
@@ -145,7 +145,7 @@ class AutoregressiveWrapper(Module):
145
145
  cache_kv = True,
146
146
  **kwargs
147
147
  ):
148
- max_seq_len, device = self.max_seq_len, prompts.device
148
+ max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
149
149
 
150
150
  prompts, ps = pack([prompts], '* n')
151
151
 
@@ -230,11 +230,14 @@ class AutoregressiveWrapper(Module):
230
230
 
231
231
  # filter by top_k, top_p (nucleus), top_a, or custom
232
232
 
233
- filtered_logits = filter_logits_fn(logits, **filter_kwargs)
233
+ if greedy:
234
+ sample = logits.argmax(dim = -1, keepdim = True)
235
+ else:
236
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
237
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
238
+ sample = torch.multinomial(probs, 1)
234
239
 
235
- probs = F.softmax(filtered_logits / temperature, dim=-1)
236
-
237
- sample = torch.multinomial(probs, 1)
240
+ # concat sample
238
241
 
239
242
  out = torch.cat((out, sample), dim=-1)
240
243
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.26.3
3
+ Version: 1.26.4
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
File without changes