x-transformers 1.35.0__tar.gz → 1.35.2__tar.gz

This diff shows the changes between publicly available package versions as released to a supported public registry, and is provided for informational purposes only.
Files changed (21)
  1. {x_transformers-1.35.0/x_transformers.egg-info → x_transformers-1.35.2}/PKG-INFO +1 -1
  2. {x_transformers-1.35.0 → x_transformers-1.35.2}/setup.py +1 -1
  3. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/autoregressive_wrapper.py +3 -2
  4. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/x_transformers.py +1 -0
  5. {x_transformers-1.35.0 → x_transformers-1.35.2/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.35.0 → x_transformers-1.35.2}/LICENSE +0 -0
  7. {x_transformers-1.35.0 → x_transformers-1.35.2}/README.md +0 -0
  8. {x_transformers-1.35.0 → x_transformers-1.35.2}/setup.cfg +0 -0
  9. {x_transformers-1.35.0 → x_transformers-1.35.2}/tests/test_x_transformers.py +0 -0
  10. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/attend.py +0 -0
  12. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.35.0 → x_transformers-1.35.2}/x_transformers.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.35.0
+Version: 1.35.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.35.0',
+  version = '1.35.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/autoregressive_wrapper.py
@@ -220,13 +220,14 @@ class AutoregressiveWrapper(Module):
             if restrict_to_max_seq_len:
                 max_len_exceeded = out.shape[-1] > max_seq_len

-                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embeeding. you can switch to rotary embeddings to resolve this issue'
+                assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

                 x = out[:, -max_seq_len:]

                 if exists(cache):
                     for inter in cache.attn_intermediates:
-                        inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
+                        if inter.layer_type == 'a':
+                            inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

             logits, new_cache = self.net(
                 x,
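The behavioural change here restricts the sliding-window truncation of cached key/values to intermediates whose layer_type is 'a' (self-attention). The likely motivation is that cross-attention ('c') caches are keyed on the encoder context, whose length does not grow as decoding proceeds, so slicing them to the decoder window would be wrong. Below is a minimal standalone sketch of that logic; CacheEntry is a simplified, hypothetical stand-in for the library's intermediates, not its real class.

from dataclasses import dataclass, field

import torch

# Hypothetical, simplified stand-in for the library's per-layer cache entry;
# the field names (layer_type, cached_kv) mirror the diff, not the real API.
@dataclass
class CacheEntry:
    layer_type: str   # 'a' = self-attention, 'c' = cross-attention
    cached_kv: list = field(default_factory=list)

def truncate_kv_cache(entries, max_seq_len):
    # Keep only the most recent (max_seq_len - 1) cached positions, and only for
    # self-attention layers; cross-attention caches are left untouched because
    # their length tracks the encoder context, not the decoded sequence.
    for inter in entries:
        if inter.layer_type == 'a':
            inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]

# toy caches with shape (batch, heads, cached positions, head dim)
self_attn  = CacheEntry('a', [torch.randn(1, 8, 300, 64), torch.randn(1, 8, 300, 64)])
cross_attn = CacheEntry('c', [torch.randn(1, 8, 77, 64), torch.randn(1, 8, 77, 64)])

truncate_kv_cache([self_attn, cross_attn], max_seq_len = 256)

print(self_attn.cached_kv[0].shape)   # torch.Size([1, 8, 255, 64])
print(cross_attn.cached_kv[0].shape)  # torch.Size([1, 8, 77, 64]) -- untouched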
x_transformers/x_transformers.py
@@ -2321,6 +2321,7 @@ class XTransformer(Module):

         self.encoder = TransformerWrapper(
             **enc_transformer_kwargs,
+            return_only_embed = True,
             attn_layers = Encoder(dim = dim, **enc_kwargs)
         )

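Passing return_only_embed = True appears to make the encoder's TransformerWrapper return its final embeddings rather than token logits, which is what the decoder cross-attends to inside XTransformer. End-user usage is unchanged; a short training-step sketch along the lines of the project README, with illustrative hyperparameters:

import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 512,
    enc_num_tokens = 256,
    enc_depth = 6,
    enc_heads = 8,
    enc_max_seq_len = 1024,
    dec_num_tokens = 256,
    dec_depth = 6,
    dec_heads = 8,
    dec_max_seq_len = 1024,
    tie_token_emb = True       # tie encoder / decoder token embeddings
)

src = torch.randint(0, 256, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 256, (1, 1024))

loss = model(src, tgt, mask = src_mask)   # decoder cross-attends to encoder embeddings
loss.backward()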
x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.35.0
+Version: 1.35.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang