x-transformers 1.40.2.tar.gz → 1.40.3.tar.gz

This diff shows the contents of publicly released package versions as they appear in their public registries and is provided for informational purposes only.
Files changed (21)
  1. {x_transformers-1.40.2/x_transformers.egg-info → x_transformers-1.40.3}/PKG-INFO +1 -1
  2. {x_transformers-1.40.2 → x_transformers-1.40.3}/README.md +2 -2
  3. {x_transformers-1.40.2 → x_transformers-1.40.3}/setup.py +1 -1
  4. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/x_transformers.py +4 -3
  5. {x_transformers-1.40.2 → x_transformers-1.40.3/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.40.2 → x_transformers-1.40.3}/LICENSE +0 -0
  7. {x_transformers-1.40.2 → x_transformers-1.40.3}/setup.cfg +0 -0
  8. {x_transformers-1.40.2 → x_transformers-1.40.3}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.40.2/x_transformers.egg-info → x_transformers-1.40.3}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.2
+Version: 1.40.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.40.2 → x_transformers-1.40.3}/README.md
@@ -525,8 +525,8 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        attn_sparse_topk = 8, # keep only the top 8 values before attention (softmax)
-        sparse_topk_straight_through = True # straight through the original gradients
+        attn_sparse_topk = 8, # keep only the top 8 values before attention (softmax)
+        attn_sparse_topk_straight_through = True # straight through the original gradients
     )
 )
 ```
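For context, the only substantive change here is the keyword rename: the straight-through flag now carries the attn_ prefix, which is how attention-specific kwargs get routed from the attention layers down to the underlying Attention blocks. A minimal sketch of the corrected snippet, assuming the standard TransformerWrapper and Decoder usage from the README (token count, sequence length, and batch shape are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_sparse_topk = 8,                    # keep only the top 8 attention logits per query before the softmax
        attn_sparse_topk_straight_through = True # let gradients flow as if no top-k filtering had been applied
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)  # (1, 1024, 20000)
```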
{x_transformers-1.40.2 → x_transformers-1.40.3}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.40.2',
+  version = '1.40.3',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.40.2 → x_transformers-1.40.3}/x_transformers/x_transformers.py
@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
         return bias
 
 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads, **kwargs):
+    def __init__(self, heads, total_heads = None, **kwargs):
         super().__init__()
         self.heads = heads
-        self.total_heads = total_heads
+        self.total_heads = default(total_heads, heads)
 
         slopes = Tensor(self._get_slopes(heads))
         slopes = rearrange(slopes, 'h -> h 1 1')
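With this change, total_heads becomes optional when constructing AlibiPositionalBias directly and falls back to heads via default(). A minimal sketch under the assumption that the module is called with the query and key lengths, as relative-position bias modules are elsewhere in this file, and returns a per-head additive bias:

```python
import torch
from x_transformers.x_transformers import AlibiPositionalBias

# previously total_heads was a required positional argument;
# now it defaults to the number of ALiBi-slope heads
alibi = AlibiPositionalBias(heads = 8)

bias = alibi(128, 128)  # additive bias for a 128 x 128 attention map
print(bias.shape)       # expected: torch.Size([8, 128, 128]) (the exact shape is an assumption)
```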
@@ -1665,6 +1665,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        attn_bias = None,
         condition = None,
         layers_execute_order: tuple[int, ...] | None = None
     ):
@@ -1818,7 +1819,7 @@ class AttentionLayers(Module):
             block = partial(block, **block_forward_kwargs)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
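Taken together, the two hunks above thread a new attn_bias keyword from AttentionLayers.forward into each self-attention ('a') block, so a caller can inject a precomputed additive bias. A hedged sketch, assuming Decoder (an AttentionLayers subclass) exposes the keyword directly and that the bias is broadcast against the (batch, heads, query, key) attention logits the way relative-position biases are:

```python
import torch
from x_transformers import Decoder

layers = Decoder(dim = 512, depth = 2, heads = 8)

x = torch.randn(1, 64, 512)    # already-embedded inputs: (batch, seq, dim)
bias = torch.randn(8, 64, 64)  # hypothetical per-head additive bias, assumed broadcastable to (batch, heads, i, j)

out = layers(x, attn_bias = bias)  # forwarded to every self-attention block per the hunk above
print(out.shape)                   # expected: torch.Size([1, 64, 512])
```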
{x_transformers-1.40.2 → x_transformers-1.40.3/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.2
+Version: 1.40.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang