x-transformers 1.40.2__py3-none-any.whl → 1.40.4__py3-none-any.whl

x_transformers/x_transformers.py

@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
         return bias
 
 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads, **kwargs):
+    def __init__(self, heads, total_heads = None, **kwargs):
         super().__init__()
         self.heads = heads
-        self.total_heads = total_heads
+        self.total_heads = default(total_heads, heads)
 
         slopes = Tensor(self._get_slopes(heads))
         slopes = rearrange(slopes, 'h -> h 1 1')
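
With this change, total_heads becomes optional and falls back to heads via the library's default() helper, so it only needs to be passed when it differs from heads. A minimal sketch of the new behavior, assuming AlibiPositionalBias is importable from x_transformers.x_transformers (the module listed in RECORD below):

from x_transformers.x_transformers import AlibiPositionalBias

# previously total_heads was a required argument
bias = AlibiPositionalBias(heads = 8)
assert bias.total_heads == 8            # now defaults to heads

# passing total_heads explicitly still behaves as before
bias = AlibiPositionalBias(heads = 8, total_heads = 16)
assert bias.total_heads == 16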
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        reinject_input = True, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
 
         self.skip_combines = ModuleList([])
 
+        # whether there is reinjection of input at every layer
+
+        self.reinject_input = reinject_input
+        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
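
The new reinject_input option stores a bias-free linear projection of the embedded input and, as the later hunks show, adds it back in before every layer. A standalone sketch of the idea, not the library's internal classes; the class name and the Linear stand-ins for attention/feedforward blocks are illustrative only:

import torch
from torch import nn

class InputReinjectionSketch(nn.Module):
    # toy stand-in: Linear blocks play the role of the attention / feedforward layers
    def __init__(self, dim, depth):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.inject_proj = nn.Linear(dim, dim, bias = False)   # mirrors reinject_input_proj above

    def forward(self, x):
        inp_inject = self.inject_proj(x)    # derived once from the initial input
        for block in self.blocks:
            x = x + inp_inject              # the input is reinjected before every block
            x = block(x)
        return x

# usage
sketch = InputReinjectionSketch(dim = 512, depth = 4)
out = sketch(torch.randn(1, 16, 512))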
@@ -1665,6 +1671,7 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
+        attn_bias = None,
         condition = None,
         layers_execute_order: tuple[int, ...] | None = None
     ):
@@ -1765,6 +1772,11 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
+        # derived input for reinjection if needed
+
+        if self.reinject_input:
+            inp_inject = self.reinject_input_proj(x)
+
         # store all hiddens for skips
 
         skip_hiddens = []
@@ -1809,6 +1821,9 @@ class AttentionLayers(Module):
             post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
             post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
 
+            if self.reinject_input:
+                x = x + inp_inject
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
@@ -1818,7 +1833,7 @@ class AttentionLayers(Module):
             block = partial(block, **block_forward_kwargs)
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
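
The forward signature also gains an attn_bias passthrough, forwarded to each self-attention block alongside the usual relative-position bias. A hedged usage sketch; the constructor keywords (dim, depth, heads) and the (heads, seq, seq) bias shape are assumptions about the surrounding library rather than something this diff shows:

import torch
from x_transformers.x_transformers import AttentionLayers

# dim / depth / heads are assumed AttentionLayers constructor kwargs (not part of this diff)
layers = AttentionLayers(dim = 512, depth = 2, heads = 8)

x = torch.randn(1, 16, 512)

# an externally computed additive attention bias; the (heads, seq, seq) shape,
# broadcast over batch, is an assumption about what the attention block accepts
attn_bias = torch.zeros(8, 16, 16)

out = layers(x, attn_bias = attn_bias)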
x_transformers-1.40.2.dist-info/METADATA → x_transformers-1.40.4.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.2
+Version: 1.40.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.40.2.dist-info/RECORD → x_transformers-1.40.4.dist-info/RECORD

@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=SfM0ql3wK7t8KzBXRNnGTdcyq3tQVmHB4VIcfg5sSv4,84604
+x_transformers/x_transformers.py,sha256=g7cGYUcq344QZejqdB_PPHHz1O_zqmdplxgQA07hqks,85298
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.2.dist-info/METADATA,sha256=G1LWuKpy25e1rXV7MFRT5r2F4bHjvyMUrgxTgJIQLic,661
-x_transformers-1.40.2.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.2.dist-info/RECORD,,
+x_transformers-1.40.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.4.dist-info/METADATA,sha256=K7tcPd4ZY7qO773XErYQbik-XDEfLbFQ0lFYSfjAFvY,661
+x_transformers-1.40.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.4.dist-info/RECORD,,