x-transformers 1.42.17__tar.gz → 1.42.19__tar.gz

Files changed (22)
  1. {x_transformers-1.42.17/x_transformers.egg-info → x_transformers-1.42.19}/PKG-INFO +3 -2
  2. {x_transformers-1.42.17 → x_transformers-1.42.19}/README.md +9 -0
  3. {x_transformers-1.42.17 → x_transformers-1.42.19}/setup.py +3 -2
  4. {x_transformers-1.42.17 → x_transformers-1.42.19}/tests/test_x_transformers.py +16 -0
  5. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/x_transformers.py +24 -1
  6. {x_transformers-1.42.17 → x_transformers-1.42.19/x_transformers.egg-info}/PKG-INFO +3 -2
  7. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers.egg-info/requires.txt +2 -1
  8. {x_transformers-1.42.17 → x_transformers-1.42.19}/LICENSE +0 -0
  9. {x_transformers-1.42.17 → x_transformers-1.42.19}/setup.cfg +0 -0
  10. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/attend.py +0 -0
  12. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/autoregressive_wrapper.py +0 -0
  13. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/continuous.py +0 -0
  14. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/dpo.py +0 -0
  15. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/multi_input.py +0 -0
  16. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/neo_mlp.py +0 -0
  17. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/nonautoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  19. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers/xval.py +0 -0
  20. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers.egg-info/SOURCES.txt +0 -0
  21. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers.egg-info/dependency_links.txt +0 -0
  22. {x_transformers-1.42.17 → x_transformers-1.42.19}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.42.17
+ Version: 1.42.19
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -14,7 +14,8 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch>=2.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: einops>=0.8.0
+ Requires-Dist: loguru
  Requires-Dist: packaging>=21.0
+ Requires-Dist: torch>=2.0
@@ -2352,4 +2352,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @inproceedings{Duvvuri2024LASERAW,
+     title  = {LASER: Attention with Exponential Transformation},
+     author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
+     year   = {2024},
+     url    = {https://api.semanticscholar.org/CorpusID:273849947}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.42.17',
+   version = '1.42.19',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
@@ -16,10 +16,11 @@ setup(
      'transformers'
    ],
    install_requires=[
-     'torch>=2.0',
      'einx>=0.3.0',
      'einops>=0.8.0',
+     'loguru',
      'packaging>=21.0',
+     'torch>=2.0',
    ],
    setup_requires=[
      'pytest-runner',
@@ -516,3 +516,19 @@ def test_to_logits(to_logits):
      output = model(x, to_logits_kwargs=to_logits_kwargs)
 
      assert output.shape == (2, 1024, 20000)
+
+ def test_laser():
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 128,
+             depth = 6,
+             heads = 8,
+             attn_laser = True
+         )
+     )
+
+     x = torch.randint(0, 20000, (2, 1024))
+
+     model(x)
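
For readers wanting to try the new option outside the test suite, here is a hedged usage sketch mirroring the test above; it relies on the library's convention that `attn_`-prefixed keyword arguments on `Decoder` are forwarded to `Attention` (so `attn_laser = True` reaches the new `laser` flag). Model sizes are illustrative.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# decoder-only model with LASER attention enabled in every layer
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_laser = True   # forwarded to Attention(laser = True)
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)           # (2, 1024, 20000)
```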
@@ -20,6 +20,8 @@ import einx
  from einops.layers.torch import Rearrange
  from einops import rearrange, repeat, reduce, pack, unpack
 
+ from loguru import logger
+
  from x_transformers.attend import Attend, Intermediates
  from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -1073,6 +1075,7 @@ class Attention(Module):
          neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
          neutreno_alpha = 0.4,
          learned_value_residual_mix = False,
+         laser = False, # https://arxiv.org/abs/2411.03493v1
          onnxable = False,
          attend_sdp_kwargs: dict = dict(
              enable_flash = True,
@@ -1112,6 +1115,11 @@
          assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
          self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+         # enhancing gradients to attention through exponentiated values
+         # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+         self.laser = laser
+
          # relations projection from tp-attention
 
          self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
@@ -1437,6 +1445,11 @@
 
              attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
 
+         if self.laser:
+             values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+             v = v - values_max
+             v = v.exp()
+
          # attention is all we need
 
          out, intermediates = self.attend(
@@ -1446,6 +1459,11 @@
              prev_attn = prev_attn
          )
 
+         # laser
+
+         if self.laser:
+             out = out.log() + values_max
+
          # store the values for resformer or Neutreno
 
          intermediates.values = orig_values
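
Taken together, the two forward-pass hunks above implement the LASER idea: the values are exponentiated (with a detached max subtracted first for numerical stability) before the usual attention, and the log of the attention output is taken afterwards, so the result equals `log(softmax(qk) @ exp(v))`. A minimal standalone sketch under that reading, using plain scaled dot-product attention instead of the library's `Attend` module (function and tensor names here are illustrative):

```python
import torch
import torch.nn.functional as F

def laser_attention(q, k, v):
    # subtract a detached max so exp() cannot overflow; the shift cancels
    # exactly when re-added after the log, giving log(softmax(qk) @ exp(v))
    v_max = v.amax(dim = -2, keepdim = True).detach()
    v_exp = (v - v_max).exp()

    scale = q.shape[-1] ** -0.5
    attn = F.softmax(q @ k.transpose(-1, -2) * scale, dim = -1)

    out = attn @ v_exp        # attention over exponentiated values
    return out.log() + v_max  # back to the original value scale

q, k, v = (torch.randn(2, 8, 128, 64) for _ in range(3))  # (batch, heads, seq, dim_head)
out = laser_attention(q, k, v)
```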
@@ -1580,7 +1598,12 @@ class AttentionLayers(Module):
 
          self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
-         rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
+         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
+
+         assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
+
+         if rotary_emb_dim < 32:
+             logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
          assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
          self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
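
The behavioural change here: a user-supplied `rotary_emb_dim` was previously raised silently to at least 32; it is now used as given, an assertion rejects anything larger than the head dimension, and values below 32 only trigger a loguru warning. A small hedged example of the new behaviour (parameter values are illustrative):

```python
from x_transformers import TransformerWrapper, Decoder

# rotary_emb_dim = 16 is now honoured as given; a warning is logged because it is below 32
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,          # dim_head defaults to 64, so 16 passes the <= dim_head assert
        rotary_pos_emb = True,
        rotary_emb_dim = 16
    )
)
```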
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.42.17
+ Version: 1.42.19
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -14,7 +14,8 @@ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3.6
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch>=2.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: einops>=0.8.0
+ Requires-Dist: loguru
  Requires-Dist: packaging>=21.0
+ Requires-Dist: torch>=2.0
@@ -1,4 +1,5 @@
- torch>=2.0
  einx>=0.3.0
  einops>=0.8.0
+ loguru
  packaging>=21.0
+ torch>=2.0