x-transformers 1.42.12__tar.gz → 1.42.14__tar.gz

Files changed (22)
  1. {x_transformers-1.42.12/x_transformers.egg-info → x_transformers-1.42.14}/PKG-INFO +1 -1
  2. {x_transformers-1.42.12 → x_transformers-1.42.14}/setup.py +1 -1
  3. {x_transformers-1.42.12 → x_transformers-1.42.14}/tests/test_x_transformers.py +5 -1
  4. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/x_transformers.py +24 -5
  5. {x_transformers-1.42.12 → x_transformers-1.42.14/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.42.12 → x_transformers-1.42.14}/LICENSE +0 -0
  7. {x_transformers-1.42.12 → x_transformers-1.42.14}/README.md +0 -0
  8. {x_transformers-1.42.12 → x_transformers-1.42.14}/setup.cfg +0 -0
  9. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.12/x_transformers.egg-info → x_transformers-1.42.14}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.12
+Version: 1.42.14
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.12 → x_transformers-1.42.14}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.12',
+  version = '1.42.14',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.42.12 → x_transformers-1.42.14}/tests/test_x_transformers.py

@@ -331,7 +331,10 @@ def test_reinject_input():
 
     model(x) # (1, 1024, 20000)
 
-def test_value_residual():
+@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+def test_value_residual(
+    learned_value_residual_mix: bool
+):
 
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -341,6 +344,7 @@ def test_value_residual():
             depth = 6,
             heads = 8,
             add_value_residual = True,
+            learned_value_residual_mix = learned_value_residual_mix
         )
     )
 
{x_transformers-1.42.12 → x_transformers-1.42.14}/x_transformers/x_transformers.py

@@ -1072,6 +1072,7 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
+        learned_value_residual_mix = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1231,6 +1232,14 @@ class Attention(Module):
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
 
+        # maybe learned value residual mixer per token
+
+        self.to_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, 1),
+            nn.Sigmoid(),
+            Rearrange('b n 1 -> b 1 n 1')
+        ) if learned_value_residual_mix else always(0.5)
+
         # attention on attention
 
         self.attn_on_attn = on_attn
@@ -1303,7 +1312,8 @@ class Attention(Module):
                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
             else:
                 # https://arxiv.org/abs/2410.17897v1
-                v = 0.5 * (v + value_residual)
+                value_residual_mix = self.to_value_residual_mix(q_input)
+                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
         # take care of caching
 
@@ -1541,8 +1551,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        learned_value_residual_mix = False, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
@@ -1786,6 +1797,10 @@ class AttentionLayers(Module):
 
         self.add_value_residual = add_value_residual
 
+        is_first_self_attn = True
+        is_first_cross_attn = True
+        learned_value_residual_mix &= add_value_residual
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1801,9 +1816,13 @@ class AttentionLayers(Module):
             # attention, cross attention, feedforward
 
             if layer_type == 'a':
-                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
+                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                is_first_self_attn = False
             elif layer_type == 'c':
-                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
+                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
+                layer = Attention(dim, heads = heads, learned_value_residual_mix = learned_value_residual_mix and not is_first_cross_attn, **{**attn_kwargs, **cross_attn_kwargs})
+                is_first_cross_attn = False
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
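Isolated from the surrounding class, the mechanism added above is small: each attention layer after the first projects its input to a per-token gate in [0, 1] and blends its own values with the values carried over from the first attention layer (the first layer supplies the residual, so the learned mix is disabled there). A self-contained sketch of that blend; ValueResidualMix and the shapes are illustrative assumptions, not the package's own API:

import torch
from torch import nn
from einops.layers.torch import Rearrange

class ValueResidualMix(nn.Module):
    # hypothetical helper mirroring the nn.Sequential added to Attention above
    def __init__(self, dim):
        super().__init__()
        self.to_mix = nn.Sequential(
            nn.Linear(dim, 1),              # one gate value per token
            nn.Sigmoid(),                   # keep it in [0, 1]
            Rearrange('b n 1 -> b 1 n 1')   # broadcast over heads and head dim
        )

    def forward(self, x, v, value_residual):
        # x:              (b, n, dim)  attention input (q_input in the diff)
        # v:              (b, h, n, d) this layer's values
        # value_residual: (b, h, n, d) values from the first attention layer
        mix = self.to_mix(x)
        return v * mix + value_residual * (1. - mix)

b, h, n, dim, d = 2, 8, 16, 512, 64
mixer = ValueResidualMix(dim)
out = mixer(torch.randn(b, n, dim), torch.randn(b, h, n, d), torch.randn(b, h, n, d))
assert out.shape == (b, h, n, d)

When the flag is off, the attention module falls back to always(0.5), i.e. the fixed average v = 0.5 * (v + value_residual) that 1.42.12 used.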
{x_transformers-1.42.12 → x_transformers-1.42.14/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.12
+Version: 1.42.14
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang