x-transformers 1.42.16__tar.gz → 1.42.18__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (22) hide show
  1. {x_transformers-1.42.16/x_transformers.egg-info → x_transformers-1.42.18}/PKG-INFO +3 -2
  2. {x_transformers-1.42.16 → x_transformers-1.42.18}/setup.py +3 -2
  3. {x_transformers-1.42.16 → x_transformers-1.42.18}/tests/test_x_transformers.py +6 -2
  4. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/x_transformers.py +11 -2
  5. {x_transformers-1.42.16 → x_transformers-1.42.18/x_transformers.egg-info}/PKG-INFO +3 -2
  6. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers.egg-info/requires.txt +2 -1
  7. {x_transformers-1.42.16 → x_transformers-1.42.18}/LICENSE +0 -0
  8. {x_transformers-1.42.16 → x_transformers-1.42.18}/README.md +0 -0
  9. {x_transformers-1.42.16 → x_transformers-1.42.18}/setup.cfg +0 -0
  10. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/attend.py +0 -0
  12. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/autoregressive_wrapper.py +0 -0
  13. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/continuous.py +0 -0
  14. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/dpo.py +0 -0
  15. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/multi_input.py +0 -0
  16. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/neo_mlp.py +0 -0
  17. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/nonautoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  19. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers/xval.py +0 -0
  20. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers.egg-info/SOURCES.txt +0 -0
  21. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers.egg-info/dependency_links.txt +0 -0
  22. {x_transformers-1.42.16 → x_transformers-1.42.18}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.16
3
+ Version: 1.42.18
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -14,7 +14,8 @@ Classifier: License :: OSI Approved :: MIT License
14
14
  Classifier: Programming Language :: Python :: 3.6
15
15
  Description-Content-Type: text/markdown
16
16
  License-File: LICENSE
17
- Requires-Dist: torch>=2.0
18
17
  Requires-Dist: einx>=0.3.0
19
18
  Requires-Dist: einops>=0.8.0
19
+ Requires-Dist: loguru
20
20
  Requires-Dist: packaging>=21.0
21
+ Requires-Dist: torch>=2.0
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
3
3
  setup(
4
4
  name = 'x-transformers',
5
5
  packages = find_packages(exclude=['examples']),
6
- version = '1.42.16',
6
+ version = '1.42.18',
7
7
  license='MIT',
8
8
  description = 'X-Transformers - Pytorch',
9
9
  author = 'Phil Wang',
@@ -16,10 +16,11 @@ setup(
16
16
  'transformers'
17
17
  ],
18
18
  install_requires=[
19
- 'torch>=2.0',
20
19
  'einx>=0.3.0',
21
20
  'einops>=0.8.0',
21
+ 'loguru',
22
22
  'packaging>=21.0',
23
+ 'torch>=2.0',
23
24
  ],
24
25
  setup_requires=[
25
26
  'pytest-runner',
@@ -352,7 +352,10 @@ def test_value_residual(
352
352
 
353
353
  model(x)
354
354
 
355
- def test_forgetting_transformer():
355
+ @pytest.mark.parametrize('has_num_mem_kv', (False, True))
356
+ def test_forgetting_transformer(
357
+ has_num_mem_kv: bool
358
+ ):
356
359
 
357
360
  model = TransformerWrapper(
358
361
  num_tokens = 20000,
@@ -361,7 +364,8 @@ def test_forgetting_transformer():
361
364
  dim = 128,
362
365
  depth = 6,
363
366
  heads = 8,
364
- attn_data_dependent_alibi = False
367
+ attn_num_mem_kv = 1 if has_num_mem_kv else 0,
368
+ attn_data_dependent_alibi = True
365
369
  )
366
370
  )
367
371
 
@@ -20,6 +20,8 @@ import einx
20
20
  from einops.layers.torch import Rearrange
21
21
  from einops import rearrange, repeat, reduce, pack, unpack
22
22
 
23
+ from loguru import logger
24
+
23
25
  from x_transformers.attend import Attend, Intermediates
24
26
  from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
25
27
 
@@ -1428,13 +1430,15 @@ class Attention(Module):
1428
1430
  else:
1429
1431
  attn_bias = rel_pos(i, j)
1430
1432
 
1431
- attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
1433
+ attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0)) # handle memory key / values
1432
1434
 
1433
1435
  # prepare data dependent alibi from forgetting transformers paper, if needed
1434
1436
 
1435
1437
  if exists(self.data_dependent_alibi):
1436
1438
  attn_bias = self.data_dependent_alibi(x)
1437
1439
 
1440
+ attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
1441
+
1438
1442
  # attention is all we need
1439
1443
 
1440
1444
  out, intermediates = self.attend(
@@ -1578,7 +1582,12 @@ class AttentionLayers(Module):
1578
1582
 
1579
1583
  self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
1580
1584
 
1581
- rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1585
+ rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
1586
+
1587
+ assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
1588
+
1589
+ if rotary_emb_dim < 32:
1590
+ logger.warning('when training language model, rotary embedding dimension should be at least 32')
1582
1591
 
1583
1592
  assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1584
1593
  self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.42.16
3
+ Version: 1.42.18
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -14,7 +14,8 @@ Classifier: License :: OSI Approved :: MIT License
14
14
  Classifier: Programming Language :: Python :: 3.6
15
15
  Description-Content-Type: text/markdown
16
16
  License-File: LICENSE
17
- Requires-Dist: torch>=2.0
18
17
  Requires-Dist: einx>=0.3.0
19
18
  Requires-Dist: einops>=0.8.0
19
+ Requires-Dist: loguru
20
20
  Requires-Dist: packaging>=21.0
21
+ Requires-Dist: torch>=2.0
@@ -1,4 +1,5 @@
1
- torch>=2.0
2
1
  einx>=0.3.0
3
2
  einops>=0.8.0
3
+ loguru
4
4
  packaging>=21.0
5
+ torch>=2.0