x-transformers 1.40.5__py3-none-any.whl → 1.40.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -22,6 +22,7 @@ class Intermediates:
     qk_similarities: Tensor | None = None
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
+    values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
 
x_transformers/x_transformers.py CHANGED
@@ -1114,6 +1114,7 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
+        value_residual = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
 
@@ -1243,6 +1244,12 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
+        # previous values passed in
+        # https://arxiv.org/abs/2410.17897v1
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # attention is all we need
 
         out, intermediates = self.attend(
@@ -1252,6 +1259,10 @@ class Attention(Module):
             prev_attn = prev_attn
         )
 
+        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+
+        intermediates.values = v
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
 
         if exists(r):
@@ -1354,6 +1365,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
         **kwargs
     ):
         super().__init__()
@@ -1588,6 +1600,10 @@ class AttentionLayers(Module):
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
 
+        # add the value from the first self attention block to all latter projected self attention values as a residual
+
+        self.add_value_residual = add_value_residual
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1787,6 +1803,8 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        first_self_attn_inter = None
+
         # go through the attention and feedforward layers
 
         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1838,13 +1856,20 @@ class AttentionLayers(Module):
 
             block = partial(block, **block_forward_kwargs)
 
+            maybe_value_residual = None
+            if self.add_value_residual and exists(first_self_attn_inter):
+                maybe_value_residual = first_self_attn_inter.values
+
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
+            if not exists(first_self_attn_inter) and layer_type == 'a':
+                first_self_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale
 
x_transformers/xval.py CHANGED
@@ -18,7 +18,8 @@ from x_transformers.x_transformers import (
     AttentionLayers,
     TokenEmbedding,
     ScaledSinusoidalEmbedding,
-    AbsolutePositionalEmbedding
+    AbsolutePositionalEmbedding,
+    always
 )
 
 from x_transformers.autoregressive_wrapper import (
x_transformers-1.40.5.dist-info/METADATA → x_transformers-1.40.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.5
+Version: 1.40.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.40.5.dist-info/RECORD → x_transformers-1.40.7.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=VbB0fi-ETgAF4dc2a_Meaqvt14LMaRVIjZ8NexUX8F0,17239
+x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=FKHaJFQuMNiFMrjDF13OE3vk-iYf_qwogBNxVpiQSc4,85671
+x_transformers/x_transformers.py,sha256=azvG00DKeg4s0tH28QeDhU6X2GaeFjq3-0RMxgVt408,86715
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
-x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.40.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.5.dist-info/METADATA,sha256=WfbVonMAKfuqdCoXwi_AfnwsmCyx1310dqKoFnEWtiY,661
-x_transformers-1.40.5.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.5.dist-info/RECORD,,
+x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
+x_transformers-1.40.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.7.dist-info/METADATA,sha256=mhcww_WPNU2q-piUEvoDEXwForJdgk9s_K6iCQVFDqo,661
+x_transformers-1.40.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.7.dist-info/RECORD,,