x-transformers 1.37.5__py3-none-any.whl → 1.37.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +15 -5
- {x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/METADATA +1 -1
- {x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/RECORD +6 -6
- {x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/top_level.txt +0 -0
x_transformers/attend.py CHANGED
@@ -64,15 +64,15 @@ print_once = once(print)
 
 # alternative distance functions
 
-def …
+def qk_l2_dist_squared(q, k):
     if k.ndim == 3:
         k = repeat(k, 'b j d -> b h j d', h = q.shape[1])
 
     q, packed_shape = pack_one(q, '* i d')
     k, _ = pack_one(k, '* j d')
 
-    …
-    return unpack_one(…
+    l2_dist_squared = torch.cdist(q, k) ** 2
+    return unpack_one(l2_dist_squared, packed_shape, '* i j')
 
 # functions for creating causal mask
 # need a special one for onnx cpu (no support for .triu)
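For context on the rename: the helper computes pairwise squared L2 distances between queries and keys, flattening the batch and head dimensions before calling `torch.cdist`. Below is a minimal standalone sketch checking the cdist route against a naive broadcast; the shapes are arbitrary, and attend.py's `pack_one` / `unpack_one` helpers are inlined via `einops.pack` / `einops.unpack`:

```python
import torch
from einops import pack, repeat, unpack

def qk_l2_dist_squared(q, k):
    # same logic as the renamed helper in attend.py
    if k.ndim == 3:
        k = repeat(k, 'b j d -> b h j d', h = q.shape[1])

    # flatten (batch, heads) so cdist sees 3d tensors
    q, packed_shape = pack([q], '* i d')
    k, _ = pack([k], '* j d')

    l2_dist_squared = torch.cdist(q, k) ** 2
    return unpack(l2_dist_squared, packed_shape, '* i j')[0]

q = torch.randn(2, 8, 16, 64)  # (batch, heads, query len, dim head)
k = torch.randn(2, 8, 32, 64)  # (batch, heads, key len, dim head)

# naive reference: pairwise differences, squared and summed over dim
reference = (q.unsqueeze(-2) - k.unsqueeze(-3)).pow(2).sum(dim = -1)

assert torch.allclose(qk_l2_dist_squared(q, k), reference, atol = 1e-3)
```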
@@ -145,7 +145,6 @@ class Attend(Module):
 
         # l2 distance attention
 
-        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
         self.l2_distance = l2_distance
 
         # add a key / value token composed of zeros
@@ -208,6 +207,17 @@ class Attend(Module):
         if v.ndim == 3:
             v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
 
+        # handle maybe l2 distance
+
+        if self.l2_distance:
+            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
+            k = F.pad(k, (0, 1), value = 1.)
+            k = torch.cat((k, -k_norm_sq), dim = -1)
+
+            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
+            q = torch.cat((2 * q, -q_norm_sq), dim = -1)
+            q = F.pad(q, (0, 1), value = 1.)
+
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
 
         if exists(self.scale):
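This augmentation is what allows the assert removed above to go: extending k to `[k, 1, -|k|^2]` and q to `[2q, -|q|^2, 1]` makes a plain dot product of the augmented vectors equal `2q·k - |q|^2 - |k|^2 = -|q - k|^2`, so flash attention's ordinary dot-product kernel computes L2 distance attention unchanged. A small self-contained check of that identity (shapes arbitrary):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 32, 64)

# augment exactly as the diff above does
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = F.pad(k, (0, 1), value = 1.)              # [k, 1]
k_aug = torch.cat((k_aug, -k_norm_sq), dim = -1)  # [k, 1, -|k|^2]

q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = torch.cat((2 * q, -q_norm_sq), dim = -1)  # [2q, -|q|^2]
q_aug = F.pad(q_aug, (0, 1), value = 1.)          # [2q, -|q|^2, 1]

# dot products of the augmented vectors ...
dots = torch.einsum('b h i d, b h j d -> b h i j', q_aug, k_aug)

# ... match the negative squared l2 distance
neg_dist_sq = -(q.unsqueeze(-2) - k.unsqueeze(-3)).pow(2).sum(dim = -1)

assert torch.allclose(dots, neg_dist_sq, atol = 1e-3)
```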
@@ -353,7 +363,7 @@ class Attend(Module):
         if not self.l2_distance:
             sim = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k)
         else:
-            sim = -…
+            sim = -qk_l2_dist_squared(q, k)
 
         sim = sim * scale
 
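Net effect: the einsum path now negates the squared distance explicitly, while the flash path relies on the q/k augmentation shown earlier. A hedged usage sketch, assuming the `Attend` constructor accepts the `flash` and `l2_distance` keywords seen in the diff and that its forward returns `(out, intermediates)`:

```python
import torch
from x_transformers.attend import Attend

# einsum path: sim = -qk_l2_dist_squared(q, k)
attend = Attend(l2_distance = True, flash = False)

# flash path, newly permitted by the assert removed in this release
# attend = Attend(l2_distance = True, flash = True)

q = torch.randn(1, 8, 32, 64)  # (batch, heads, seq len, dim head)
k = torch.randn(1, 8, 32, 64)
v = torch.randn(1, 8, 32, 64)

out, intermediates = attend(q, k, v)
print(out.shape)  # expected: torch.Size([1, 8, 32, 64])
```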
{x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/RECORD CHANGED

@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=…
+x_transformers/attend.py,sha256=BOnMjgV5O5DyAM_bUz1rI6n1j_eLXu8GIEljT-MMnWU,13434
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.…
-x_transformers-1.37.…
-x_transformers-1.37.…
-x_transformers-1.37.…
-x_transformers-1.37.…
+x_transformers-1.37.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.7.dist-info/METADATA,sha256=svdldk1hpiBN4xxKkOaQuMEXAxvaj6fT7Ri9NXwZJCU,661
+x_transformers-1.37.7.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.7.dist-info/RECORD,,
{x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/LICENSE: file without changes
{x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/WHEEL: file without changes
{x_transformers-1.37.5.dist-info → x_transformers-1.37.7.dist-info}/top_level.txt: file without changes