x-transformers 1.39.1__py3-none-any.whl → 1.39.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +18 -6
- x_transformers/x_transformers.py +3 -0
- {x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/METADATA +1 -1
- {x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/RECORD +7 -7
- {x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/WHEEL +1 -1
- {x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -107,6 +107,15 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
+# gumbel softmax
+
+def gumbel_softmax(t, temperature = 1.):
+    one_hot_indices = t.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+
+    t = (t / temperature).softmax(dim = -1)
+    return one_hot + t - t.detach()
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
 
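The hunk above adds a straight-through hard attention helper: the forward value is (up to float rounding) the one-hot of the argmax, while gradients flow through the tempered softmax. A minimal standalone check of that behavior, with the new helper copied as-is and illustrative test tensors:

import torch

def gumbel_softmax(t, temperature = 1.):
    # hard one-hot forward, soft softmax gradient backward (straight-through)
    one_hot_indices = t.argmax(dim = -1, keepdim = True)
    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)

    t = (t / temperature).softmax(dim = -1)
    return one_hot + t - t.detach()

logits = torch.randn(2, 4, requires_grad = True)
attn = gumbel_softmax(logits)

# forward output matches the one-hot argmax
one_hot = torch.zeros_like(attn).scatter(-1, logits.argmax(dim = -1, keepdim = True), 1.)
assert torch.allclose(attn, one_hot)

# backward receives the softmax gradient rather than zero
(attn * torch.arange(4.)).sum().backward()
print(logits.grad)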
@@ -142,6 +151,7 @@ class Attend(Module):
         logit_softclamp_value = 50.,
         add_zero_kv = False,
         selective = False,
+        hard = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -162,17 +172,18 @@ class Attend(Module):
         # attention type
 
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
-        assert
-
-        self.sigmoid = sigmoid
+        assert not (flash and hard), 'hard attention not available for flash'
+        assert at_most_one_of(sigmoid, hard, l2_distance)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
-        elif
+        elif sigmoid:
+            self.attn_fn = F.sigmoid
+        elif hard:
+            self.attn_fn = gumbel_softmax
+        else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
-        else:
-            self.attn_fn = F.sigmoid
 
         # dropouts
 
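With the dispatch above, hard = True selects gumbel_softmax as the attention normalization (and is asserted to be incompatible with flash attention). A hedged sketch of exercising the flag on Attend directly; the (batch, heads, seq, dim_head) layout and the (out, intermediates) return value are assumptions about the module's existing interface, not something this diff confirms:

import torch
from x_transformers.attend import Attend

attend = Attend(hard = True, causal = True, dropout = 0.)   # non-flash path

# assumed layout: (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 32, 64)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)

out, intermediates = attend(q, k, v)
print(out.shape)   # expected (1, 8, 32, 64)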
@@ -487,6 +498,7 @@ class Attend(Module):
             sim = sim + sim.sigmoid().log()
 
         attn = self.attn_fn(sim)
+
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
x_transformers/x_transformers.py
CHANGED
@@ -919,6 +919,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        hard = False,
         sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
@@ -1043,6 +1044,7 @@ class Attention(Module):
             pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
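The two hunks above expose the flag on Attention and forward it into Attend. A usage sketch at the model level, assuming x-transformers' usual routing of attn_-prefixed kwargs from the attention layers down to Attention (attn_hard is therefore an assumed spelling, not confirmed by this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_hard = True   # assumed to reach Attention(hard = True), then Attend(hard = True)
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)   # (1, 128, 256)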
@@ -1612,6 +1614,7 @@ class AttentionLayers(Module):
             layer = post_branch_fn(layer)
 
             residual_fn = GRUGating if gate_residual else Residual
+
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
             # handle unet skip connection
{x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
+x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
+x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.3.dist-info/RECORD,,
{x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/LICENSE
File without changes
{x_transformers-1.39.1.dist-info → x_transformers-1.39.3.dist-info}/top_level.txt
File without changes