x-transformers 1.40.0.tar.gz → 1.40.2.tar.gz

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.40.0/x_transformers.egg-info → x_transformers-1.40.2}/PKG-INFO +1 -1
  2. {x_transformers-1.40.0 → x_transformers-1.40.2}/setup.py +1 -1
  3. {x_transformers-1.40.0 → x_transformers-1.40.2}/tests/test_x_transformers.py +0 -22
  4. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/attend.py +7 -2
  5. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/x_transformers.py +2 -1
  6. {x_transformers-1.40.0 → x_transformers-1.40.2/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.40.0 → x_transformers-1.40.2}/LICENSE +0 -0
  8. {x_transformers-1.40.0 → x_transformers-1.40.2}/README.md +0 -0
  9. {x_transformers-1.40.0 → x_transformers-1.40.2}/setup.cfg +0 -0
  10. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/__init__.py +0 -0
  11. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.40.0/x_transformers.egg-info → x_transformers-1.40.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.0
+Version: 1.40.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.40.0 → x_transformers-1.40.2}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.40.0',
+  version = '1.40.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.40.0 → x_transformers-1.40.2}/tests/test_x_transformers.py

@@ -280,28 +280,6 @@ def test_mos():
 
     eval_logits = model(x)
 
-def test_sigsoftmax():
-    model = TransformerWrapper(
-        num_tokens = 20000,
-        max_seq_len = 1024,
-        mixture_of_softmax = True,
-        sigsoftmax_logits = True,
-        attn_layers = Decoder(
-            attn_sigsoftmax = True,
-            dim = 128,
-            depth = 6,
-            heads = 8
-        )
-    )
-
-    x = torch.randint(0, 20000, (2, 1024))
-
-    logits = model(x)
-
-    model.eval()
-
-    eval_logits = model(x)
-
 @pytest.mark.parametrize('attn_one_kv_head', (True, False))
 def test_l2_distance(attn_one_kv_head):
 
{x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/attend.py

@@ -119,7 +119,12 @@ def one_hot_straight_through(logits, temperature = 1.):
 # sparse topk attention - only keep topk attn logits for softmax
 # optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
 
-def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = False):
+def sparse_topk_attn(
+    logits,
+    sparse_topk,
+    temperature = 1.,
+    straight_through = False
+):
     orig_logits = logits
 
     mask_value = -torch.finfo(logits.dtype).max
@@ -132,7 +137,7 @@ def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = F
         return topk_attn
 
     soft_attn = (orig_logits / temperature).softmax(dim = -1)
-    return topk_attn + soft_attn - soft_attn.detach()
+    return topk_attn.detach() + soft_attn - soft_attn.detach()
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
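The second attend.py hunk is the substantive fix: with `attn_sparse_topk_straight_through = True`, the function should return the sparse top-k attention in the forward pass while letting gradients flow only through the dense softmax. Detaching `topk_attn` removes the hard path from the backward graph, so only `soft_attn` is differentiated. A minimal standalone sketch of that straight-through pattern (names and shapes here are illustrative, not the library's internals):

import torch

def straight_through_topk_softmax(logits, k, temperature = 1.):
    # hard path: softmax restricted to the k largest logits per row
    mask_value = -torch.finfo(logits.dtype).max
    top_values, _ = logits.topk(k, dim = -1)
    topk_mask = logits >= top_values[..., -1:]
    topk_attn = logits.masked_fill(~topk_mask, mask_value).softmax(dim = -1)

    # soft path: ordinary tempered softmax over all logits
    soft_attn = (logits / temperature).softmax(dim = -1)

    # forward value equals topk_attn; gradients come from soft_attn only,
    # because the hard path is detached from the autograd graph
    return topk_attn.detach() + soft_attn - soft_attn.detach()

logits = torch.randn(2, 8, requires_grad = True)
out = straight_through_topk_softmax(logits, k = 3)
out.sum().backward()    # grad flows through the dense softmax, not the top-k mask

Since `soft_attn - soft_attn.detach()` is numerically zero, the forward output is the same before and after the fix; only the backward graph changes.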
{x_transformers-1.40.0 → x_transformers-1.40.2}/x_transformers/x_transformers.py

@@ -1107,6 +1107,7 @@ class Attention(Module):
         context_mask = None,
         attn_mask = None,
         rel_pos = None,
+        attn_bias = None,
         rotary_pos_emb = None,
         prev_attn = None,
         mem = None,
@@ -1237,8 +1238,8 @@ class Attention(Module):
 
         # prepare relative positional bias, if needed
 
-        attn_bias = None
         if exists(rel_pos):
+            assert not exists(attn_bias)
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
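In x_transformers.py, `attn_bias` is promoted from a local variable inside `Attention.forward` to a keyword argument, so a caller can hand in a precomputed additive bias; the new assert forbids combining it with `rel_pos`, which would otherwise overwrite the passed-in value with `rel_pos(i, j)`. The diff does not show how the bias is consumed downstream, but as a rough illustration of what an additive attention bias does (a standalone sketch with assumed shapes, not the library's internal code), it is added to the pre-softmax attention logits:

import torch

def attend_with_bias(q, k, v, attn_bias = None):
    # q, k, v: (batch, heads, seq, dim_head)
    # attn_bias: anything broadcastable to (batch, heads, q_len, k_len),
    # e.g. a per-head relative position bias of shape (heads, q_len, k_len)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    if attn_bias is not None:
        sim = sim + attn_bias           # bias is added to the pre-softmax logits

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)
bias = torch.randn(8, 16, 16)           # hypothetical per-head bias, broadcast over batch
out = attend_with_bias(q, k, v, bias)   # (1, 8, 16, 64)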
{x_transformers-1.40.0 → x_transformers-1.40.2/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.0
+Version: 1.40.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang