x-transformers 1.44.5__tar.gz → 1.44.7__tar.gz

Files changed (22)
  1. {x_transformers-1.44.5/x_transformers.egg-info → x_transformers-1.44.7}/PKG-INFO +1 -1
  2. {x_transformers-1.44.5 → x_transformers-1.44.7}/setup.py +1 -1
  3. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/x_transformers.py +48 -11
  4. {x_transformers-1.44.5 → x_transformers-1.44.7/x_transformers.egg-info}/PKG-INFO +1 -1
  5. {x_transformers-1.44.5 → x_transformers-1.44.7}/LICENSE +0 -0
  6. {x_transformers-1.44.5 → x_transformers-1.44.7}/README.md +0 -0
  7. {x_transformers-1.44.5 → x_transformers-1.44.7}/setup.cfg +0 -0
  8. {x_transformers-1.44.5 → x_transformers-1.44.7}/tests/test_x_transformers.py +0 -0
  9. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.44.5/x_transformers.egg-info → x_transformers-1.44.7}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.5
+ Version: 1.44.7
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.44.5 → x_transformers-1.44.7}/setup.py

@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
- version = '1.44.5',
+ version = '1.44.7',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
{x_transformers-1.44.5 → x_transformers-1.44.7}/x_transformers/x_transformers.py

@@ -882,6 +882,7 @@ class HyperConnection(Module):
  *,
  layer_index,
  num_residual_streams,
+ num_input_views = 1,
  tanh = True,
  **kwargs
  ):
@@ -900,13 +901,16 @@ class HyperConnection(Module):

  self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

- init_alpha0 = torch.zeros((num_residual_streams, 1))
- init_alpha0[layer_index % num_residual_streams, 0] = 1.
+ init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+ init_alpha0[layer_index % num_residual_streams, :] = 1.

  self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

- self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+ self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+ self.num_input_views = num_input_views
+
  self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
  self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

@@ -928,7 +932,13 @@ class HyperConnection(Module):

  mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)

- branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+ views = self.num_input_views
+
+ if views == 1:
+     branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+ else:
+     branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
+     branch_input = rearrange(branch_input, '... v d -> v ... d')

  return branch_input, residuals, dict(beta = beta)

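To make the new num_input_views option concrete, here is a minimal, self-contained sketch of the mixing step in the hunk above, using plain torch and einops rather than the library's HyperConnection class (shapes and names follow the diff; the batch and sequence sizes are made up):

import torch
from torch import einsum
from einops import rearrange

num_residual_streams, num_input_views, dim = 4, 3, 8
layer_index = 1

# static alpha as initialized above: (streams, views + streams)
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
init_alpha0[layer_index % num_residual_streams, :] = 1.
static_alpha = torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)

residuals = torch.randn(2, 16, num_residual_streams, dim)   # (batch, seq, streams, dim)

# the same mixing einsum as in the forward hunk
mix_h = einsum('... s t, ... s d -> ... t d', static_alpha, residuals)

# the first `views` outputs feed the block, the rest are the carried residual streams
branch_input = rearrange(mix_h[..., :num_input_views, :], '... v d -> v ... d')
new_residuals = mix_h[..., num_input_views:, :]

q_input, k_input, v_input = branch_input                    # what an attention layer unpacks when
print(q_input.shape, new_residuals.shape)                   # qkv_receive_diff_residuals is enabled

The library additionally mixes in a data-dependent term via dynamic_alpha_fn (now sized num_residual_streams + num_input_views); this sketch keeps only the static part.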
@@ -1200,6 +1210,7 @@ class Attention(Module):
  learned_value_residual_mix = False,
  laser = False, # https://arxiv.org/abs/2411.03493v1
  laser_softclamp_value = 15.,
+ qkv_receive_diff_residuals = False,
  onnxable = False,
  attend_sdp_kwargs: dict = dict(
      enable_flash = True,
@@ -1239,6 +1250,10 @@ class Attention(Module):
  assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
  self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

+ # whether qkv receives different residual stream combinations from hyper connections
+
+ self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
+
  # enhancing gradients to attention through exponentiated values

  self.laser = laser
@@ -1423,14 +1438,21 @@ class Attention(Module):
  cache: Intermediates | None = None,
  value_residual = None
  ):
- b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
+ b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals

- kv_input = default(context, x)
+ assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

- q_input = x
- k_input = kv_input
- v_input = kv_input
- r_input = x
+ if qkv_receive_diff_residuals:
+     assert not exists(self.to_r)
+
+     q_input, k_input, v_input = x
+ else:
+     kv_input = default(context, x)
+
+     q_input = x
+     k_input = kv_input
+     v_input = kv_input
+     r_input = x

  if exists(mem):
      k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
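The forward change above amounts to a small input-routing branch. The following standalone sketch (a hypothetical helper, not the library's Attention.forward; the r_input path is omitted) shows the two paths side by side:

import torch

def split_attention_inputs(x, qkv_receive_diff_residuals = False, context = None):
    # sketch of the routing added in the diff
    if qkv_receive_diff_residuals:
        assert context is None, 'qkv receiving different sequences can only be used for self attention'
        q_input, k_input, v_input = x                # x is a stack of three views: (3, batch, seq, dim)
    else:
        kv_input = context if context is not None else x
        q_input, k_input, v_input = x, kv_input, kv_input
    return q_input, k_input, v_input

views = torch.randn(3, 2, 16, 512)                   # three residual-stream mixes from a HyperConnection
q_in, k_in, v_in = split_attention_inputs(views, qkv_receive_diff_residuals = True)
print(q_in.shape)                                    # torch.Size([2, 16, 512])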
@@ -1735,7 +1757,9 @@ class AttentionLayers(Module):
  layerscale_init_value = 0.,
  unet_skips = False,
  num_residual_streams = 1,
+ qkv_receive_diff_residuals = False,
  reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+ learned_reinject_input_gate = False,
  add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
  learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
  rel_pos_kwargs: dict = dict(),
@@ -1770,6 +1794,8 @@ class AttentionLayers(Module):

  assert not (num_residual_streams > 1 and gate_residual)

+ assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+
  # positions related

  self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
@@ -1993,6 +2019,7 @@ class AttentionLayers(Module):

  self.reinject_input = reinject_input
  self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+ self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None

  # add the value from the first self attention block to all latter projected self attention values as a residual

@@ -2018,7 +2045,7 @@ class AttentionLayers(Module):

  if layer_type == 'a':
      self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-     layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+     layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
      is_first_self_attn = False
  elif layer_type == 'c':
      layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
@@ -2039,6 +2066,10 @@ class AttentionLayers(Module):

  if num_residual_streams > 1:
      residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+
+     if layer_type == 'a':
+         residual_fn = partial(residual_fn, num_input_views = 3)
+
  elif gate_residual:
      residual_fn = GRUGating
  else:
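Putting the AttentionLayers wiring above together: with hyper connections enabled, each self attention layer now gets three input views, one each for q, k and v. A hedged usage sketch follows (it assumes the package's usual TransformerWrapper / Decoder entry points; only the two residual-stream arguments come from this diff, and qkv_receive_diff_residuals requires num_residual_streams > 1 per the assertion added earlier):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        num_residual_streams = 4,             # hyper connections
        qkv_receive_diff_residuals = True     # q, k, v each receive their own mix of the streams
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)                        # (1, 1024, 256)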
@@ -2224,7 +2255,9 @@ class AttentionLayers(Module):
  layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

  # derived input for reinjection if needed
+
  inp_inject = None
+
  if self.reinject_input:
      assert not exists(in_attn_cond)
      inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2266,10 @@ class AttentionLayers(Module):
  # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
  inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')

+ if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+     inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+     inp_inject = inp_inject * inp_inject_gate
+
  # store all hiddens for skips

  skip_hiddens = []
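The last two x_transformers.py hunks add an optional learned gate on the reinjected input: when learned_reinject_input_gate is passed to the attention layers, the reinjected tensor is scaled per token by sigmoid(Linear(x)) before being added at each layer. A minimal sketch of just that gating math, using plain torch modules with the same shapes as the diff (the sizes here are made up):

import torch
import torch.nn as nn

dim = 512
x = torch.randn(1, 128, dim)                          # token embeddings entering the layer stack
reinject_input_proj = nn.Linear(dim, dim, bias = False)
learned_reinject_input_gate = nn.Linear(dim, 1, bias = False)

inp_inject = reinject_input_proj(x)                   # the existing reinject_input projection
gate = learned_reinject_input_gate(x).sigmoid()       # (1, 128, 1): one gate value per token
inp_inject = inp_inject * gate                        # what gets added back at every layer

In the library this is switched on by passing reinject_input = True together with learned_reinject_input_gate = True to the attention layers, both of which appear as keyword arguments in this diff.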
{x_transformers-1.44.5 → x_transformers-1.44.7/x_transformers.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.5
+ Version: 1.44.7
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang