x-transformers 1.40.7__py3-none-any.whl → 1.40.8__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

x_transformers/x_transformers.py
@@ -1803,7 +1803,10 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for value residuals
+
         first_self_attn_inter = None
+        first_cross_attn_inter = None
 
         # go through the attention and feedforward layers
 
@@ -1856,20 +1859,35 @@ class AttentionLayers(Module):
 
             block = partial(block, **block_forward_kwargs)
 
-            maybe_value_residual = None
-            if self.add_value_residual and exists(first_self_attn_inter):
-                maybe_value_residual = first_self_attn_inter.values
+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
+            # store first self or cross attention intermediate for value residual
+
             if not exists(first_self_attn_inter) and layer_type == 'a':
                 first_self_attn_inter = inter
 
+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale
 

x_transformers-1.40.7.dist-info/METADATA → x_transformers-1.40.8.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.7
+Version: 1.40.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

x_transformers-1.40.7.dist-info/RECORD → x_transformers-1.40.8.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=azvG00DKeg4s0tH28QeDhU6X2GaeFjq3-0RMxgVt408,86715
+x_transformers/x_transformers.py,sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY,87367
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.7.dist-info/METADATA,sha256=mhcww_WPNU2q-piUEvoDEXwForJdgk9s_K6iCQVFDqo,661
-x_transformers-1.40.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.7.dist-info/RECORD,,
+x_transformers-1.40.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.8.dist-info/METADATA,sha256=epku2YLENdljdFU7ge7hhUDXMujquHyb0jhoDSv9yFk,661
+x_transformers-1.40.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.8.dist-info/RECORD,,
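For reference, each RECORD line above has the form `<path>,sha256=<urlsafe base64 digest with padding stripped>,<size in bytes>` (the RECORD entry itself is left blank). A small sketch of how such an entry can be recomputed to verify an installed file against the wheel's RECORD; `record_entry` is an illustrative helper, not part of any packaging library.

```python
import base64
import hashlib
import os

def record_entry(path: str) -> str:
    # Rebuild a wheel RECORD line: path, urlsafe-base64 sha256 digest (no '=' padding), file size
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    encoded = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    return f"{path},sha256={encoded},{os.path.getsize(path)}"

# e.g. record_entry('x_transformers/x_transformers.py') should reproduce the updated line above
```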