x-transformers 1.37.6-py3-none-any.whl → 1.37.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +11 -1
- {x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/METADATA +1 -1
- {x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/RECORD +6 -6
- {x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -145,7 +145,6 @@ class Attend(Module):
 
         # l2 distance attention
 
-        assert not (flash and l2_distance), 'l2 distance attention does not work with flash attention just yet'
         self.l2_distance = l2_distance
 
         # add a key / value token composed of zeros
@@ -208,6 +207,17 @@ class Attend(Module):
         if v.ndim == 3:
             v = repeat(v, 'b ... -> b h ...', h = q.shape[1])
 
+        # handle maybe l2 distance
+
+        if self.l2_distance:
+            k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
+            k = F.pad(k, (0, 1), value = -1.)
+            k = torch.cat((k, k_norm_sq), dim = -1)
+
+            q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
+            q = torch.cat((2 * q, q_norm_sq), dim = -1)
+            q = F.pad(q, (0, 1), value = -1.)
+
         # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
 
         if exists(self.scale):
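For context on what the new block computes: the keys are augmented to k' = [k, -1, |k|^2] and the queries to q' = [2q, |q|^2, -1], so a plain dot product gives q'.k' = 2 q.k - |q|^2 - |k|^2 = -|q - k|^2. The L2-distance attention logits therefore fall out of the ordinary dot-product attention path (including the flash kernels), which is why the assert removed in the first hunk is no longer needed. The snippet below is a standalone numerical check of that identity, not code from the package; shapes and variable names are illustrative.

# standalone check (not from the package): the augmented dot product equals
# the negative squared L2 distance between the original queries and keys
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 64, dtype = torch.float64)   # (batch, seq, dim), illustrative shapes
k = torch.randn(2, 8, 64, dtype = torch.float64)

# mirror the diff: k' = [k, -1, |k|^2] and q' = [2q, |q|^2, -1]
k_norm_sq = k.norm(dim = -1, keepdim = True) ** 2
k_aug = torch.cat((F.pad(k, (0, 1), value = -1.), k_norm_sq), dim = -1)

q_norm_sq = q.norm(dim = -1, keepdim = True) ** 2
q_aug = F.pad(torch.cat((2 * q, q_norm_sq), dim = -1), (0, 1), value = -1.)

sim = torch.einsum('b i d, b j d -> b i j', q_aug, k_aug)   # 2 q.k - |q|^2 - |k|^2
neg_sq_dist = -(torch.cdist(q, k) ** 2)                     # -|q - k|^2

assert torch.allclose(sim, neg_sq_dist)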
{x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=we7mkwVCD7_ColUD8_Fj0HM5jjOaa3wbstllp_XXK4k,13434
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
@@ -8,8 +8,8 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-2fj6QcDSfMI5lJA_fzOW2mdzdS1C1LD6jMBtGQY48E,83752
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
-x_transformers-1.37.
+x_transformers-1.37.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.37.8.dist-info/METADATA,sha256=fiT94VbrxWL-8jJBjxvFloWsH6n6reOGitRSlpAhvWs,661
+x_transformers-1.37.8.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.37.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.37.8.dist-info/RECORD,,
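As a side note on reading the RECORD hunk: each entry has the form path,sha256=<digest>,<size>, where the digest is the urlsafe base64 encoding of the file's SHA-256 hash with the trailing '=' padding stripped. A rough sketch for reproducing an entry from a local file follows; the helper name is ours, not part of any package.

# rough sketch, assuming a locally available file: rebuild one RECORD entry
import base64, hashlib

def record_entry(path):
    with open(path, 'rb') as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# e.g. record_entry('x_transformers/attend.py') should reproduce the "+" attend.py line above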
{x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/LICENSE
File without changes
{x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/WHEEL
File without changes
{x_transformers-1.37.6.dist-info → x_transformers-1.37.8.dist-info}/top_level.txt
File without changes