x-transformers 1.40.2-py3-none-any.whl → 1.40.3-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
         return bias

 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads, **kwargs):
+    def __init__(self, heads, total_heads = None, **kwargs):
         super().__init__()
         self.heads = heads
-        self.total_heads = total_heads
+        self.total_heads = default(total_heads, heads)

         slopes = Tensor(self._get_slopes(heads))
         slopes = rearrange(slopes, 'h -> h 1 1')
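
The hunk above makes `total_heads` optional on `AlibiPositionalBias`: when it is omitted, the library's `default` helper falls back to `heads`. A minimal sketch of what that means for callers (construction only; the head count of 8 is illustrative):

    from x_transformers.x_transformers import AlibiPositionalBias

    # Before 1.40.3, total_heads had to be supplied explicitly.
    alibi_explicit = AlibiPositionalBias(heads = 8, total_heads = 8)

    # From 1.40.3 on, omitting total_heads is equivalent to total_heads = heads.
    alibi_default = AlibiPositionalBias(heads = 8)
    assert alibi_default.total_heads == 8
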
@@ -1665,6 +1665,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        attn_bias = None,
         condition = None,
         layers_execute_order: tuple[int, ...] | None = None
     ):
@@ -1818,7 +1819,7 @@ class AttentionLayers(Module):
                 block = partial(block, **block_forward_kwargs)

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
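
Together with the new `attn_bias` parameter on `AttentionLayers.forward` (the `@@ -1665` hunk above), this lets a caller supply an additive attention bias from outside, which is then forwarded to every self-attention block. A minimal usage sketch, assuming the bias is added to the pre-softmax attention scores and that a `(heads, seq, seq)` tensor broadcasts across the batch; the shapes and the zero-valued bias are illustrative, not taken from this diff:

    import torch
    from x_transformers import Encoder

    enc = Encoder(dim = 512, depth = 6, heads = 8)

    x = torch.randn(1, 128, 512)

    # Externally supplied additive bias, one (seq, seq) map per head.
    attn_bias = torch.zeros(8, 128, 128)

    out = enc(x, attn_bias = attn_bias)  # passed through to each layer_type == 'a' block
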
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.2
+Version: 1.40.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=SfM0ql3wK7t8KzBXRNnGTdcyq3tQVmHB4VIcfg5sSv4,84604
+x_transformers/x_transformers.py,sha256=17wNbMh5Mt6cyrAvkBToIozmFt-p9ZBhQCqqlnDyHPI,84676
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.2.dist-info/METADATA,sha256=G1LWuKpy25e1rXV7MFRT5r2F4bHjvyMUrgxTgJIQLic,661
-x_transformers-1.40.2.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.2.dist-info/RECORD,,
+x_transformers-1.40.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.3.dist-info/METADATA,sha256=2Qb9WF8pa8bmqMI4VdO9f9bBEmKMcUatlbcU2SDJGwM,661
+x_transformers-1.40.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.3.dist-info/RECORD,,