x-transformers 1.40.6__py3-none-any.whl → 1.40.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -22,6 +22,7 @@ class Intermediates:
     qk_similarities: Tensor | None = None
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
+    values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None

x_transformers/x_transformers.py CHANGED
@@ -1114,6 +1114,7 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
+        value_residual = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1243,6 +1244,12 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

+        # previous values passed in
+        # https://arxiv.org/abs/2410.17897v1
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # attention is all we need

         out, intermediates = self.attend(
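At the tensor level, the change to Attention.forward is a single element-wise add on the projected values before they are aggregated by the attention weights (the "ResFormer" value residual of Zhou et al., https://arxiv.org/abs/2410.17897). A minimal sketch, with shapes chosen purely for illustration:

    import torch

    v_first = torch.randn(2, 8, 128, 64)  # values projected by the first attention layer
    v = torch.randn(2, 8, 128, 64)        # values projected by the current layer
    attn = torch.softmax(torch.randn(2, 8, 128, 128), dim = -1)  # post-softmax attention weights

    v = v + v_first     # the residual added in this hunk
    out = attn @ v      # value aggregation itself is unchanged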
@@ -1252,6 +1259,10 @@ class Attention(Module):
             prev_attn = prev_attn
         )

+        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+
+        intermediates.values = v
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

         if exists(r):
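The values stored on the Intermediates object here are what the layer stack later passes back in as value_residual, via the new values field added to the dataclass in attend.py (first hunk above). A minimal sketch of populating and reading that field; the (batch, heads, seq, dim_head) shape is an assumption chosen only for illustration:

    import torch
    from x_transformers.attend import Intermediates

    # sketch only: the library fills this field inside Attention.forward;
    # it is constructed by hand here just to show the new attribute
    inter = Intermediates(values = torch.randn(2, 8, 128, 64))
    print(inter.values.shape)   # torch.Size([2, 8, 128, 64])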
@@ -1354,6 +1365,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
         **kwargs
     ):
         super().__init__()
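Since Encoder and Decoder forward their keyword arguments to AttentionLayers, the new flag should be reachable from the usual wrapper API. A hedged usage sketch (standard x-transformers constructor arguments; only add_value_residual is new in this release):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            add_value_residual = True   # resformer-style value residual from the first self attention block
        )
    )

    x = torch.randint(0, 20000, (1, 256))
    logits = model(x)                   # (1, 256, 20000)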
@@ -1588,6 +1600,10 @@ class AttentionLayers(Module):
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

+        # add the value from the first self attention block to all latter projected self attention values as a residual
+
+        self.add_value_residual = add_value_residual
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1787,6 +1803,11 @@ class AttentionLayers(Module):

         skip_hiddens = []

+        # for value residuals
+
+        first_self_attn_inter = None
+        first_cross_attn_inter = None
+
         # go through the attention and feedforward layers

         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1838,13 +1859,35 @@ class AttentionLayers(Module):

             block = partial(block, **block_forward_kwargs)

+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type
+
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

+            # store first self or cross attention intermediate for value residual
+
+            if not exists(first_self_attn_inter) and layer_type == 'a':
+                first_self_attn_inter = inter
+
+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale

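The loop above captures the Intermediates of the first self attention layer (and, separately, the first cross attention layer) and feeds its stored values into every later layer of the same type. A self-contained toy sketch of that wiring; ToyAttention is a stand-in written for this note, not the library's Attention class:

    import torch
    from torch import nn

    class ToyAttention(nn.Module):
        # stand-in that mirrors only the value-residual plumbing, not real attention
        def __init__(self, dim):
            super().__init__()
            self.to_v = nn.Linear(dim, dim, bias = False)
            self.to_out = nn.Linear(dim, dim, bias = False)

        def forward(self, x, value_residual = None):
            v = self.to_v(x)
            if value_residual is not None:
                v = v + value_residual    # same add as in the Attention.forward hunk above
            return self.to_out(v), v      # return the output and the values that were used

    layers = nn.ModuleList([ToyAttention(64) for _ in range(4)])
    x = torch.randn(2, 16, 64)

    first_values = None
    for layer in layers:
        x, values = layer(x, value_residual = first_values)
        if first_values is None:
            first_values = values         # capture once, from the first layer only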
x_transformers-1.40.6.dist-info/METADATA → x_transformers-1.40.8.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.6
+Version: 1.40.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.40.6.dist-info/RECORD → x_transformers-1.40.8.dist-info/RECORD RENAMED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=VbB0fi-ETgAF4dc2a_Meaqvt14LMaRVIjZ8NexUX8F0,17239
+x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
+x_transformers/x_transformers.py,sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY,87367
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.6.dist-info/METADATA,sha256=HsoNUu71hkonsBhThVN46rakFnIAGOav3pHDpYnX9t8,661
-x_transformers-1.40.6.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.6.dist-info/RECORD,,
+x_transformers-1.40.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.8.dist-info/METADATA,sha256=epku2YLENdljdFU7ge7hhUDXMujquHyb0jhoDSv9yFk,661
+x_transformers-1.40.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.8.dist-info/RECORD,,