x-transformers 1.40.0__py3-none-any.whl → 1.40.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -119,7 +119,12 @@ def one_hot_straight_through(logits, temperature = 1.):
119
119
  # sparse topk attention - only keep topk attn logits for softmax
120
120
  # optional straight through with masked out logits by setting `attn_sparse_topk_straight_through = True`
121
121
 
122
- def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = False):
122
+ def sparse_topk_attn(
123
+ logits,
124
+ sparse_topk,
125
+ temperature = 1.,
126
+ straight_through = False
127
+ ):
123
128
  orig_logits = logits
124
129
 
125
130
  mask_value = -torch.finfo(logits.dtype).max
@@ -132,7 +137,7 @@ def sparse_topk_attn(logits, sparse_topk, temperature = 1., straight_through = F
132
137
  return topk_attn
133
138
 
134
139
  soft_attn = (orig_logits / temperature).softmax(dim = -1)
135
- return topk_attn + soft_attn - soft_attn.detach()
140
+ return topk_attn.detach() + soft_attn - soft_attn.detach()
136
141
 
137
142
  # functions for creating causal mask
138
143
  # need a special one for onnx cpu (no support for .triu)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.40.0
3
+ Version: 1.40.1
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,5 +1,5 @@
1
1
  x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
2
- x_transformers/attend.py,sha256=eoBEK0HdDCWaJgxwGZPeO36ydBt1NbB-gpij_Jkj4Mw,17212
2
+ x_transformers/attend.py,sha256=VbB0fi-ETgAF4dc2a_Meaqvt14LMaRVIjZ8NexUX8F0,17239
3
3
  x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
4
4
  x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
8
8
  x_transformers/x_transformers.py,sha256=TGXJZXCWR5BiMkS5Kx-JhFQ85AxkiJabLiHnrCTC874,84562
9
9
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
10
10
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
11
- x_transformers-1.40.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
- x_transformers-1.40.0.dist-info/METADATA,sha256=WxBpjG7F8utkdJN1AF9vZMAfbltT0lmpdgtUErBbYQY,661
13
- x_transformers-1.40.0.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
14
- x_transformers-1.40.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
- x_transformers-1.40.0.dist-info/RECORD,,
11
+ x_transformers-1.40.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
+ x_transformers-1.40.1.dist-info/METADATA,sha256=WouMl3Ld1llknOwj7BcKi-_YZ9Hx9RZ-ni-eGCP_uQY,661
13
+ x_transformers-1.40.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
14
+ x_transformers-1.40.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
+ x_transformers-1.40.1.dist-info/RECORD,,