x-transformers 1.39.3__py3-none-any.whl → 1.39.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -107,9 +107,9 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
-# gumbel softmax
+# one-hot straight through softmax
 
-def gumbel_softmax(t, temperature = 1.):
+def one_hot_straight_through(t, temperature = 1.):
     one_hot_indices = t.argmax(dim = -1, keepdim = True)
     one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
 
@@ -180,7 +180,7 @@ class Attend(Module):
         elif sigmoid:
            self.attn_fn = F.sigmoid
         elif hard:
-            self.attn_fn = gumbel_softmax
+            self.attn_fn = one_hot_straight_through
         else:
            softmax_fn = partial(F.softmax, dim = -1)
            self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
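
The hunks above rename the hard-attention helper from gumbel_softmax to one_hot_straight_through and point Attend(hard = True) at the new name; the visible changes are a pure rename (comment, definition, and the assignment in Attend), which also matches the 38-byte growth of attend.py recorded in RECORD below. Only the first two lines of the renamed function appear in the diff context. As a rough sketch of the pattern the new name describes (the completion below is an assumption, not the exact body shipped in 1.39.4): a straight-through one-hot estimator returns the hard argmax one-hot on the forward pass while letting gradients flow through a temperature-scaled softmax, e.g.:

    import torch

    def one_hot_straight_through(t, temperature = 1.):
        # hard one-hot over the last dimension (these two lines appear in the diff)
        one_hot_indices = t.argmax(dim = -1, keepdim = True)
        one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)

        # assumed completion: forward the hard one-hot, backpropagate through
        # the temperature-scaled softmax (straight-through gradient trick)
        soft = (t / temperature).softmax(dim = -1)
        return one_hot + soft - soft.detach()
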
x_transformers-1.39.3.dist-info/METADATA → x_transformers-1.39.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.39.3
+Version: 1.39.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.39.3.dist-info/RECORD → x_transformers-1.39.4.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
+x_transformers/attend.py,sha256=013qsFwoABVbyc-1L3RZTRCWo6BW9fAD8IVnC_qALGk,16708
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
-x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.39.3.dist-info/RECORD,,
+x_transformers-1.39.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.4.dist-info/METADATA,sha256=2KawHim0IOdlRjbRJCVsELM10T7nojxnMy6WrWtG0UE,661
+x_transformers-1.39.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.39.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.4.dist-info/RECORD,,