x-transformers 1.39.0__py3-none-any.whl → 1.39.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +33 -9
- x_transformers/x_transformers.py +5 -0
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/METADATA +1 -1
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/RECORD +7 -7
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/WHEEL +1 -1
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -107,6 +107,15 @@ def qk_l2_dist_squared(q, k):
     l2_dist_squared = torch.cdist(q, k) ** 2
     return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
+# gumbel softmax
+
+def gumbel_softmax(t, temperature = 1.):
+    one_hot_indices = t.argmax(dim = -1, keepdim = True)
+    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)
+
+    t = (t / temperature).softmax(dim = -1)
+    return one_hot + t - t.detach()
+
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
 
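For context, the new `gumbel_softmax` helper (despite the name, no Gumbel noise is sampled) is a straight-through estimator: the forward value is a hard one-hot of the argmax, while `one_hot + t - t.detach()` routes gradients through the temperature-scaled softmax. A minimal standalone sketch of that behavior (not part of the package):

```python
import torch

def gumbel_softmax(t, temperature = 1.):
    # same logic as the helper added in attend.py:
    # hard one-hot in the forward pass, softmax gradients in the backward pass
    one_hot_indices = t.argmax(dim = -1, keepdim = True)
    one_hot = torch.zeros_like(t).scatter(-1, one_hot_indices, 1.)

    t = (t / temperature).softmax(dim = -1)
    return one_hot + t - t.detach()

logits = torch.randn(2, 4, requires_grad = True)
attn = gumbel_softmax(logits)
print(attn)   # each row is exactly one-hot

# gradients flow through the softmax surrogate, not the hard argmax
(attn * torch.randn(2, 4)).sum().backward()
print(logits.grad)
```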
@@ -128,8 +137,9 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        pre_talking_heads =
-        post_talking_heads =
+        pre_talking_heads = False,
+        post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -141,6 +151,7 @@ class Attend(Module):
         logit_softclamp_value = 50.,
         add_zero_kv = False,
         selective = False,
+        hard = False,
         sigsoftmax = False,
         cope = None,
         onnxable = False,
@@ -161,17 +172,18 @@ class Attend(Module):
         # attention type
 
         assert not (flash and sigmoid), 'sigmoid attention not available for flash'
-        assert
-
-        self.sigmoid = sigmoid
+        assert not (flash and hard), 'hard attention not available for flash'
+        assert at_most_one_of(sigmoid, hard, l2_distance)
 
         if exists(custom_attn_fn):
             self.attn_fn = custom_attn_fn
-        elif
+        elif sigmoid:
+            self.attn_fn = F.sigmoid
+        elif hard:
+            self.attn_fn = gumbel_softmax
+        else:
             softmax_fn = partial(F.softmax, dim = -1)
             self.attn_fn = partial(softmax_fn, dtype = torch.float32) if not qk_norm else softmax_fn
-        else:
-            self.attn_fn = F.sigmoid
 
         # dropouts
 
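With `hard = True`, `Attend` swaps its `attn_fn` from softmax to the straight-through helper above (and, like sigmoid attention, it is asserted to be incompatible with flash). A hedged sketch of driving `Attend` directly; the tensor shapes and the returned `Intermediates` fields are assumptions about the library's usual non-flash path, not shown in this diff:

```python
import torch
from x_transformers.attend import Attend

# hard = True is the new flag from 1.39.x; flash must stay off
attend = Attend(causal = True, hard = True)

b, h, n, d = 1, 8, 128, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

out, intermediates = attend(q, k, v)
print(out.shape)                                        # torch.Size([1, 8, 128, 64])
print(intermediates.post_softmax_attn.amax(dim = -1))   # all ones: each query picks exactly one key
```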
@@ -180,10 +192,11 @@ class Attend(Module):
 
         # talking heads
 
-        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'
+        assert not (flash and (pre_talking_heads or post_talking_heads or pre_scale_post_talking_heads)), 'talking heads not compatible with flash attention'
 
         self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
         self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None
+        self.pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_scale_post_talking_heads else None
 
         if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
@@ -191,6 +204,10 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)
 
+        if exists(self.pre_scale_post_talking_heads):
+            # an improvisation where heads are combined pre-softmax attention, then used to scale post-softmax attention
+            nn.init.dirac_(self.pre_scale_post_talking_heads.weight)
+
         # selective attention
 
         assert not (flash and selective), 'selective attention cannot work on flash attention'
@@ -436,6 +453,9 @@ class Attend(Module):
 
         qk_similarities = sim.clone()
 
+        if exists(self.pre_scale_post_talking_heads):
+            pre_to_post_scale = self.pre_scale_post_talking_heads(sim)
+
         if exists(self.pre_softmax_talking_heads):
             sim = sim + self.pre_softmax_talking_heads(sim)
 
@@ -478,6 +498,7 @@ class Attend(Module):
             sim = sim + sim.sigmoid().log()
 
         attn = self.attn_fn(sim)
+
         attn = attn.type(dtype)
 
         post_softmax_attn = attn
@@ -487,6 +508,9 @@ class Attend(Module):
         if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)
 
+        if exists(self.pre_scale_post_talking_heads):
+            attn = attn * pre_to_post_scale
+
         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
 
         intermediates = Intermediates(
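The other addition, `pre_scale_post_talking_heads`, follows the inline comment above: a 1x1 convolution mixes heads on the pre-softmax similarity matrix, and that mixed map is then used to rescale the post-softmax attention. A standalone sketch of the data flow (shapes are illustrative, not from the diff):

```python
import torch
from torch import nn

b, heads, i, j = 1, 8, 16, 16
sim = torch.randn(b, heads, i, j)                    # pre-softmax attention logits, (b h i j)

pre_scale_post_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
nn.init.dirac_(pre_scale_post_talking_heads.weight)  # identity init, as in the diff

pre_to_post_scale = pre_scale_post_talking_heads(sim)   # heads combined on the logits
attn = sim.softmax(dim = -1)
attn = attn * pre_to_post_scale                          # post-softmax attention scaled by it

print(attn.shape)   # torch.Size([1, 8, 16, 16])
```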
x_transformers/x_transformers.py
CHANGED
@@ -909,6 +909,7 @@ class Attention(Module):
         flash = False,
         pre_talking_heads = False,
         post_talking_heads = False,
+        pre_scale_post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -918,6 +919,7 @@ class Attention(Module):
         swiglu_values = False,
         gate_values = False,
         zero_init_output = False,
+        hard = False,
         sigsoftmax = False,
         max_attend_past = None,
         qk_norm = False,
@@ -1039,8 +1041,10 @@ class Attention(Module):
             causal = causal,
             pre_talking_heads = pre_talking_heads,
             post_talking_heads = post_talking_heads,
+            pre_scale_post_talking_heads = pre_scale_post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
+            hard = hard,
             qk_norm = qk_norm,
             scale = qk_norm_scale if qk_norm else self.scale,
             l2_distance = l2_distance,
@@ -1610,6 +1614,7 @@ class AttentionLayers(Module):
            layer = post_branch_fn(layer)
 
            residual_fn = GRUGating if gate_residual else Residual
+
            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
            # handle unet skip connection
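Both new knobs surface as ordinary `Attention` keyword arguments (`hard` and `pre_scale_post_talking_heads`). A usage sketch, assuming the library's usual `attn_`-prefixed kwarg routing through `Decoder`; that routing is not part of this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_hard = True,                            # straight-through hard attention
        attn_pre_scale_post_talking_heads = True     # pre-softmax head mix reused as post-softmax scale
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)
```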
{x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=K4DNcWAgOHzNtOTIJA15VA3VQ2KMyv-PX8oO1R0Z5Rw,16670
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=3JrVqwYVrd5UVf2esdunTcer7QL72H7VF4mL3UsCWOI,84508
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
-x_transformers-1.39.
+x_transformers-1.39.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.3.dist-info/METADATA,sha256=aqCockVfuZfLM3D9ZdlgN_HfUPdhZm4UWyKg2HLkuUo,661
+x_transformers-1.39.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.39.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.3.dist-info/RECORD,,
{x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/LICENSE
File without changes
{x_transformers-1.39.0.dist-info → x_transformers-1.39.3.dist-info}/top_level.txt
File without changes