x-transformers 1.16.6__py3-none-any.whl → 1.16.7__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- x_transformers/attend.py +4 -6
- x_transformers/x_transformers.py +4 -2
- {x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/METADATA +1 -1
- {x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/RECORD +7 -7
- {x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -253,9 +253,9 @@ class Attend(nn.Module):
 
 # cascading heads logic
 
-def to_single_heads(t):
-    heads = t.unbind(dim = 1)
-    return tuple(head.unsqueeze(1) for head in heads)
+def to_single_heads(t, dim = 1):
+    heads = t.unbind(dim = dim)
+    return tuple(head.unsqueeze(dim) for head in heads)
 
 class CascadingHeads(nn.Module):
     def __init__(self, attend: Attend):
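
A minimal sketch of the generalized helper as it appears in the hunk above; the tensor and its shape are illustrative, not taken from the package:

    import torch

    # split a tensor into per-head views along `dim`, keeping that dimension with size 1
    def to_single_heads(t, dim = 1):
        heads = t.unbind(dim = dim)
        return tuple(head.unsqueeze(dim) for head in heads)

    q = torch.randn(2, 8, 1024, 64)   # (batch, heads, seq_len, dim_head)
    per_head = to_single_heads(q)     # 8 tensors, each of shape (2, 1, 1024, 64)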
@@ -281,9 +281,7 @@ class CascadingHeads(nn.Module):
 
         mask = (mask,) * heads
 
-        attn_bias = attn_bias
-        attn_bias = map(lambda t: rearrange(t, '... -> 1 ...'), attn_bias)
-
+        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
         prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
 
         # now loop through each head, without output of previous head summed with the next head
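
The hunk above is why the helper gained its dim argument: an attention bias such as ALiBi typically has shape (heads, i, j) with no batch dimension, so its heads sit at dim 0 rather than dim 1. Continuing the sketch above (shapes again illustrative):

    bias = torch.randn(8, 1024, 1024)               # (heads, i, j): no batch dimension
    per_head_bias = to_single_heads(bias, dim = 0)  # 8 tensors, each of shape (1, 1024, 1024)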
x_transformers/x_transformers.py
CHANGED
@@ -358,7 +358,7 @@ class AlibiPositionalBias(nn.Module):
     def forward(self, i, j):
         h, device = self.total_heads, self.device
 
-        if exists(self.bias) and self.bias.shape[-1] >= j:
+        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., :i, :j]
 
         bias = self.get_bias(i, j, device)
@@ -382,7 +382,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
         def get_slopes(param):
             return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
 
-        if exists(self.bias) and self.bias.shape[-1] >= j:
+        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             bias = self.bias[..., :i, :j]
         else:
             bias = self.get_bias(i, j, device)
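
Both hunks above tighten the same cache check: the buffered ALiBi bias is reused only if it covers both the current query length i and key length j, not just j. A standalone sketch of the case the extra condition guards against (names and sizes are hypothetical):

    import torch

    cached_bias = torch.randn(8, 4, 16)   # (heads, i_cached, j_cached) left over from an earlier call

    i, j = 6, 8                           # new call needs more query rows but fewer key columns

    old_check = cached_bias.shape[-1] >= j                                  # True: the stale cache would be reused
    sliced = cached_bias[..., :i, :j]                                       # shape (8, 4, 8): only 4 of the 6 rows needed
    new_check = cached_bias.shape[-1] >= j and cached_bias.shape[-2] >= i   # False: the bias is recomputed instead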
@@ -971,6 +971,8 @@ class AttentionLayers(nn.Module):
 
         self.residual_attn = residual_attn
         self.cross_residual_attn = cross_residual_attn
+        assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
+
         self.cross_attend = cross_attend
 
         norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
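
A sketch of the reasoning behind the new assertion, assuming residual attention here means the RealFormer-style scheme in which each layer adds the previous layer's pre-softmax attention scores to its own, and assuming PyTorch 2.x for the fused kernel:

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(1, 8, 128, 64) for _ in range(3))

    # residual attention needs the previous layer's pre-softmax score matrix ...
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (64 ** -0.5)
    prev_attn = sim   # would be added to the next layer's scores before softmax

    # ... but a fused flash kernel returns only the attention output; the (i, j)
    # score matrix is never materialized, so there is nothing to carry forward
    out = F.scaled_dot_product_attention(q, k, v)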
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=ivhVpP_5vd6798HNq92DY0XZWjAJGmpE4qOdKW5yRaI,10379
 x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=xc4b05Y9vlGBXayJvpK775r4Dr7NlVusIVdqS3I09-4,54199
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
-x_transformers-1.16.
+x_transformers-1.16.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.16.7.dist-info/METADATA,sha256=wpuosM4b40fjCe0WHAilFQDVzLDx7_yqIyImTp-2380,665
+x_transformers-1.16.7.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+x_transformers-1.16.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.16.7.dist-info/RECORD,,
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/LICENSE
File without changes
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/WHEEL
File without changes
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/top_level.txt
File without changes