x-transformers 1.16.6__py3-none-any.whl → 1.16.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -253,9 +253,9 @@ class Attend(nn.Module):
 
 # cascading heads logic
 
-def to_single_heads(t):
-    heads = t.unbind(dim = 1)
-    return tuple(rearrange(head, 'b ... -> b 1 ...') for head in heads)
+def to_single_heads(t, dim = 1):
+    heads = t.unbind(dim = dim)
+    return tuple(head.unsqueeze(dim) for head in heads)
 
 class CascadingHeads(nn.Module):
     def __init__(self, attend: Attend):
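The refactor swaps the einops `rearrange` call for a plain `unsqueeze` and makes the split dimension configurable, so the same helper can also handle tensors whose head axis comes first. A minimal sketch of what the new helper does, on a dummy tensor (the shape is illustrative, not taken from the library):

```python
import torch

def to_single_heads(t, dim = 1):
    # split the head axis into a tuple of single-head tensors,
    # keeping a size-1 head axis on each piece
    heads = t.unbind(dim = dim)
    return tuple(head.unsqueeze(dim) for head in heads)

q = torch.randn(2, 8, 128, 64)          # (batch, heads, seq, dim_head) - illustrative shape

singles = to_single_heads(q)             # default dim = 1 splits the head axis
print(len(singles), singles[0].shape)    # 8 torch.Size([2, 1, 128, 64])

# concatenating the pieces round-trips back to the original tensor
assert torch.equal(torch.cat(singles, dim = 1), q)
```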
@@ -281,9 +281,7 @@ class CascadingHeads(nn.Module):
 
         mask = (mask,) * heads
 
-        attn_bias = attn_bias.unbind(dim = 0) if exists(attn_bias) else ((None,) * heads)
-        attn_bias = map(lambda t: rearrange(t, '... -> 1 ...'), attn_bias)
-
+        attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
 
         prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)
 
         # now loop through each head, without output of previous head summed with the next head
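Unlike the queries and keys, `attn_bias` carries no batch dimension here, so its head axis is first and the call site passes `dim = 0`; the one call then replaces the old `unbind` + `rearrange('... -> 1 ...')` pair. A quick check under that shape assumption:

```python
import torch

def to_single_heads(t, dim = 1):
    return tuple(head.unsqueeze(dim) for head in t.unbind(dim = dim))

attn_bias = torch.randn(8, 128, 128)     # (heads, i, j) - assumed shape, no batch axis

# dim = 0 splits the leading head axis and re-inserts a size-1 head dim,
# matching what unbind + rearrange('... -> 1 ...') produced before
per_head = to_single_heads(attn_bias, dim = 0)
print(len(per_head), per_head[0].shape)  # 8 torch.Size([1, 128, 128])
```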
@@ -358,7 +358,7 @@ class AlibiPositionalBias(nn.Module):
     def forward(self, i, j):
         h, device = self.total_heads, self.device
 
-        if exists(self.bias) and self.bias.shape[-1] >= j:
+        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., :i, :j]
 
         bias = self.get_bias(i, j, device)
@@ -382,7 +382,7 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
         def get_slopes(param):
             return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
 
-        if exists(self.bias) and self.bias.shape[-1] >= j:
+        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             bias = self.bias[..., :i, :j]
         else:
             bias = self.get_bias(i, j, device)
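Both ALiBi classes cache the computed bias and re-slice it on later calls. The old guard validated only the key length `j`, so a bias cached for a shorter query length could be reused, and `self.bias[..., :i, :j]` would silently return fewer rows than requested; the fix also checks the `-2` (query) dimension. A toy reproduction of the stale-cache case (sizes are illustrative):

```python
import torch

cached = torch.zeros(8, 4, 16)   # cached bias: (heads, i = 4, j = 16)

i, j = 6, 12                     # a later call asks for more query rows than were cached

old_hit = cached.shape[-1] >= j                             # True  -> stale cache reused
new_hit = cached.shape[-1] >= j and cached.shape[-2] >= i   # False -> bias recomputed

# under the old check, slicing quietly yields 4 rows instead of the 6 requested
print(cached[..., :i, :j].shape)  # torch.Size([8, 4, 12])
```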
@@ -971,6 +971,8 @@ class AttentionLayers(nn.Module):
 
         self.residual_attn = residual_attn
         self.cross_residual_attn = cross_residual_attn
+        assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
+
         self.cross_attend = cross_attend
 
         norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
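Residual attention (RealFormer-style) feeds each layer's pre-softmax attention matrix into the next layer, but flash attention kernels never materialize that matrix, so the combination now fails fast at construction instead of misbehaving later. A hedged usage sketch; the kwarg names follow the project README's `attn_flash` / `residual_attn` convention and may differ across versions:

```python
from x_transformers import TransformerWrapper, Decoder

# requesting both options together now trips the new assert in AttentionLayers
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_flash = True,      # route attention through flash kernels
        residual_attn = True    # RealFormer-style residual attention
    )
)
# AssertionError: flash attention is not compatible with residual attention
```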
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.16.6
+Version: 1.16.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.16.6.dist-info → x_transformers-1.16.7.dist-info}/RECORD RENAMED
@@ -1,12 +1,12 @@
 x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
-x_transformers/attend.py,sha256=dm301IVJAxVI7UthAoyr2wrxaC4bgprjZL3hmEVZ91M,10450
+x_transformers/attend.py,sha256=ivhVpP_5vd6798HNq92DY0XZWjAJGmpE4qOdKW5yRaI,10379
 x_transformers/autoregressive_wrapper.py,sha256=u2celA8KeHm_Gd83Q7qaiLbJnwaDGdsbUck-JiokpKg,4446
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=Rax04MaANgByO2ZERoqptGD4Lo-RL2nsWAFE85nQ3_I,54004
+x_transformers/x_transformers.py,sha256=xc4b05Y9vlGBXayJvpK775r4Dr7NlVusIVdqS3I09-4,54199
 x_transformers/xl_autoregressive_wrapper.py,sha256=-CAYjTtqrks8ZTxjYm2stOelZpU4MbZIvLjUxWO0P9Y,3988
-x_transformers-1.16.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.16.6.dist-info/METADATA,sha256=69XyDEAwgGUIZxWik2cqgqyeZqFxnN4KjWSeKPn0pzY,665
-x_transformers-1.16.6.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
-x_transformers-1.16.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.16.6.dist-info/RECORD,,
+x_transformers-1.16.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.16.7.dist-info/METADATA,sha256=wpuosM4b40fjCe0WHAilFQDVzLDx7_yqIyImTp-2380,665
+x_transformers-1.16.7.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
+x_transformers-1.16.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.16.7.dist-info/RECORD,,