x_transformers-1.40.4-py3-none-any.whl → x_transformers-1.40.5-py3-none-any.whl

This diff shows the changes between the publicly released 1.40.4 and 1.40.5 wheels, as they appear in the package registry.
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1353,7 +1353,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = True, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         **kwargs
     ):
         super().__init__()
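
The only behavioral change in this hunk is the flipped default: input reinjection is now opt-in. A minimal sketch of enabling it explicitly from 1.40.5 onward, assuming the usual x-transformers constructor API (the dim/depth/heads values below are illustrative, not from the diff):

    import torch
    from x_transformers import Encoder

    # reinject_input defaulted to True in 1.40.4; from 1.40.5 it must be requested explicitly
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        reinject_input = True  # re-add a projection of the initial input at every layer
    )

    x = torch.randn(2, 1024, 512)  # (batch, seq, dim) embeddings
    out = attn_layers(x)           # -> (2, 1024, 512)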
@@ -1673,6 +1673,7 @@ class AttentionLayers(Module):
         rotary_pos_emb = None,
         attn_bias = None,
         condition = None,
+        in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
@@ -1775,8 +1776,13 @@ class AttentionLayers(Module):
         # derived input for reinjection if needed

         if self.reinject_input:
+            assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)

+        elif exists(in_attn_cond):
+            # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
+            inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
+
         # store all hiddens for skips

         skip_hiddens = []
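
A hedged usage sketch for the new argument: per the hunk above, in_attn_cond may be (batch, dim) or (batch, seq, dim), a 2-dim tensor is rearranged to (batch, 1, dim), and the assert forbids combining it with reinject_input. This assumes Encoder still inherits AttentionLayers.forward, as in prior releases:

    import torch
    from x_transformers import Encoder

    # reinject_input is left at its new default of False, as the assert requires
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)

    x = torch.randn(2, 1024, 512)  # (batch, seq, dim) token embeddings
    cond = torch.randn(2, 512)     # (batch, dim) conditioning, e.g. a pooled style embedding

    # broadcast internally to (2, 1, 512) and injected the same way as the reinjected input
    out = attn_layers(x, in_attn_cond = cond)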
--- a/x_transformers-1.40.4.dist-info/METADATA
+++ b/x_transformers-1.40.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.4
+Version: 1.40.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- a/x_transformers-1.40.4.dist-info/RECORD
+++ b/x_transformers-1.40.5.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=g7cGYUcq344QZejqdB_PPHHz1O_zqmdplxgQA07hqks,85298
+x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.4.dist-info/METADATA,sha256=K7tcPd4ZY7qO773XErYQbik-XDEfLbFQ0lFYSfjAFvY,661
-x_transformers-1.40.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.4.dist-info/RECORD,,
+x_transformers-1.40.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.5.dist-info/METADATA,sha256=WfbVonMAKfuqdCoXwi_AfnwsmCyx1310dqKoFnEWtiY,661
+x_transformers-1.40.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.5.dist-info/RECORD,,