x-transformers 1.40.1__py3-none-any.whl → 1.40.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. The information is provided for informational purposes only.
@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
         return bias

 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads, **kwargs):
+    def __init__(self, heads, total_heads = None, **kwargs):
         super().__init__()
         self.heads = heads
-        self.total_heads = total_heads
+        self.total_heads = default(total_heads, heads)

         slopes = Tensor(self._get_slopes(heads))
         slopes = rearrange(slopes, 'h -> h 1 1')
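With total_heads now optional, AlibiPositionalBias can be constructed from the head count alone. A minimal sketch of standalone use, assuming the class is importable from x_transformers/x_transformers.py and keeps the rel_pos(i, j) calling convention seen in the Attention hunk below; the output shape is an assumption inferred from the slopes rearrange above:

from x_transformers.x_transformers import AlibiPositionalBias

# total_heads falls back to heads via default(total_heads, heads)
alibi = AlibiPositionalBias(heads = 8)

# called the same way the Attention block calls rel_pos(i, j):
# query length i, key length j
bias = alibi(128, 128)

# assumed to be an additive attention bias of shape (heads, i, j)
print(bias.shape)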
@@ -1107,6 +1107,7 @@ class Attention(Module):
         context_mask = None,
         attn_mask = None,
         rel_pos = None,
+        attn_bias = None,
         rotary_pos_emb = None,
         prev_attn = None,
         mem = None,
@@ -1237,8 +1238,8 @@ class Attention(Module):

         # prepare relative positional bias, if needed

-        attn_bias = None
         if exists(rel_pos):
+            assert not exists(attn_bias)
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

@@ -1664,6 +1665,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        attn_bias = None,
         condition = None,
         layers_execute_order: tuple[int, ...] | None = None
     ):
@@ -1817,7 +1819,7 @@ class AttentionLayers(Module):
             block = partial(block, **block_forward_kwargs)

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
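Taken together, the Attention and AttentionLayers hunks thread a user-supplied attn_bias from the layer stack's forward down into each self-attention block, with the new assert guarding against combining it with a configured rel_pos module. A minimal sketch of how a caller might use this, assuming the usual Encoder wrapper around AttentionLayers and a bias broadcastable over the batch as (heads, i, j); the shapes and hyperparameters here are illustrative, not taken from this diff:

import torch
from x_transformers import Encoder

# an AttentionLayers subclass with no rel_pos configured,
# so the new `assert not exists(attn_bias)` is never triggered
encoder = Encoder(dim = 512, depth = 2, heads = 8)

x = torch.randn(1, 128, 512)             # (batch, seq, dim) embeddings

# custom additive attention bias, assumed broadcastable as (heads, i, j)
bias = torch.zeros(8, 128, 128)
bias[:, :, :16] = -1e4                    # e.g. strongly down-weight the first 16 keys

out = encoder(x, attn_bias = bias)        # forwarded to every self-attention block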
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.1
+Version: 1.40.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=TGXJZXCWR5BiMkS5Kx-JhFQ85AxkiJabLiHnrCTC874,84562
+x_transformers/x_transformers.py,sha256=17wNbMh5Mt6cyrAvkBToIozmFt-p9ZBhQCqqlnDyHPI,84676
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.1.dist-info/METADATA,sha256=WouMl3Ld1llknOwj7BcKi-_YZ9Hx9RZ-ni-eGCP_uQY,661
-x_transformers-1.40.1.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.1.dist-info/RECORD,,
+x_transformers-1.40.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.3.dist-info/METADATA,sha256=2Qb9WF8pa8bmqMI4VdO9f9bBEmKMcUatlbcU2SDJGwM,661
+x_transformers-1.40.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.3.dist-info/RECORD,,