x-transformers 1.39.3__py3-none-any.whl → 1.39.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +3 -3
- {x_transformers-1.39.3.dist-info → x_transformers-1.39.4.dist-info}/METADATA +1 -1
- {x_transformers-1.39.3.dist-info → x_transformers-1.39.4.dist-info}/RECORD +6 -6
- {x_transformers-1.39.3.dist-info → x_transformers-1.39.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.39.3.dist-info → x_transformers-1.39.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.39.3.dist-info → x_transformers-1.39.4.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -107,9 +107,9 @@ def qk_l2_dist_squared(q, k):
|
|
107
107
|
l2_dist_squared = torch.cdist(q, k) ** 2
|
108
108
|
return unpack_one(l2_dist_squared, packed_shape, '* i j')
|
109
109
|
|
110
|
-
#
|
110
|
+
# one-hot straight through softmax
|
111
111
|
|
112
|
-
def
|
112
|
+
def one_hot_straight_through(t, temperature = 1.):
|
113
113
|
one_hot_indices = t.argmax(dim = -1, keepdim = True)
|
114
114
|
one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
|
115
115
|
|
@@ -180,7 +180,7 @@ class Attend(Module):
|
|
180
180
|
elif sigmoid:
|
181
181
|
self.attn_fn = F.sigmoid
|
182
182
|
elif hard:
|
183
|
-
self.attn_fn =
|
183
|
+
self.attn_fn = one_hot_straight_through
|
184
184
|
else:
|
185
185
|
softmax_fn = partial(F.softmax, dim = -1)
|
186
186
|
self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
|
@@ -1,5 +1,5 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
|
2
|
-
x_transformers/attend.py,sha256=
|
2
|
+
x_transformers/attend.py,sha256=013qsFwoABVbyc-1L3RZTRCWo6BW9fAD8IVnC_qALGk,16708
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
|
4
4
|
x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
8
8
|
x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
|
9
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.39.
|
12
|
-
x_transformers-1.39.
|
13
|
-
x_transformers-1.39.
|
14
|
-
x_transformers-1.39.
|
15
|
-
x_transformers-1.39.
|
11
|
+
x_transformers-1.39.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.39.4.dist-info/METADATA,sha256=2KawHim0IOdlRjbRJCVsELM10T7nojxnMy6WrWtG0UE,661
|
13
|
+
x_transformers-1.39.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
14
|
+
x_transformers-1.39.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.39.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|