x-transformers 1.44.6__py3-none-any.whl → 1.44.7__py3-none-any.whl

x_transformers/x_transformers.py

@@ -882,6 +882,7 @@ class HyperConnection(Module):
         *,
         layer_index,
         num_residual_streams,
+        num_input_views = 1,
         tanh = True,
         **kwargs
     ):
@@ -900,13 +901,16 @@ class HyperConnection(Module):
 
         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
 
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+        init_alpha0[layer_index % num_residual_streams, :] = 1.
 
         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
 
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        self.num_input_views = num_input_views
+
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
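For intuition on the shape change above: the static alpha matrix now carries one routing column per input view, followed by an identity block over the residual streams. A minimal sketch of the initialization, mirroring the diffed code; the concrete sizes (4 streams, 3 views, layer index 1) are illustrative only:

import torch

num_residual_streams = 4   # illustrative
num_input_views = 3        # the value AttentionLayers wires up for attention blocks further below
layer_index = 1

# one column per input view, all initially routed from the same (rotating) residual stream
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
init_alpha0[layer_index % num_residual_streams, :] = 1.

# identity block keeps every residual stream flowing through untouched
static_alpha = torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)

print(static_alpha.shape)  # torch.Size([4, 7]) -> (num_residual_streams, num_input_views + num_residual_streams)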
@@ -928,7 +932,13 @@ class HyperConnection(Module):
 
         mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
 
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        views = self.num_input_views
+
+        if views == 1:
+            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        else:
+            branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
+            branch_input = rearrange(branch_input, '... v d -> v ... d')
 
         return branch_input, residuals, dict(beta = beta)
 
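The width connection then slices the mixed streams accordingly: the first num_input_views slots along the stream axis become per-view branch inputs, moved to a leading axis, while the rest remain residual streams. A small shape check with a stand-in tensor (the batch, sequence and dim sizes are made up; the real mix_h comes from the einsum above):

import torch
from einops import rearrange

b, n, d = 2, 8, 16                 # illustrative sizes
num_residual_streams, views = 4, 3

# stand-in for mix_h, with (num_input_views + num_residual_streams) slots on the stream axis
mix_h = torch.randn(b, n, views + num_residual_streams, d)

branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
branch_input = rearrange(branch_input, '... v d -> v ... d')

print(branch_input.shape)  # torch.Size([3, 2, 8, 16]) - one view each for q, k and v
print(residuals.shape)     # torch.Size([2, 8, 4, 16])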
@@ -1200,6 +1210,7 @@ class Attention(Module):
         learned_value_residual_mix = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
+        qkv_receive_diff_residuals = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1239,6 +1250,10 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # whether qkv receives different residual stream combinations from hyper connections
+
+        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
+
         # enhancing gradients to attention through exponentiated values
 
         self.laser = laser
@@ -1423,14 +1438,21 @@ class Attention(Module):
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
 
-        kv_input = default(context, x)
+        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
-        q_input = x
-        k_input = kv_input
-        v_input = kv_input
-        r_input = x
+        if qkv_receive_diff_residuals:
+            assert not exists(self.to_r)
+
+            q_input, k_input, v_input = x
+        else:
+            kv_input = default(context, x)
+
+            q_input = x
+            k_input = kv_input
+            v_input = kv_input
+            r_input = x
 
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
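On the attention side, when qkv_receive_diff_residuals is set the block no longer treats x as a single sequence; it expects the stacked views produced by the hyper connection and unpacks them along the leading axis. A minimal sketch of that unpacking (sizes again illustrative):

import torch

# what the hyper connection hands the attention block: (views, batch, seq, dim)
x = torch.randn(3, 2, 8, 512)

q_input, k_input, v_input = x  # iterating a tensor splits along its first dimension
print(q_input.shape)           # torch.Size([2, 8, 512])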
@@ -1735,6 +1757,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         num_residual_streams = 1,
+        qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
@@ -1771,6 +1794,8 @@ class AttentionLayers(Module):
 
         assert not (num_residual_streams > 1 and gate_residual)
 
+        assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+
         # positions related
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
@@ -2020,7 +2045,7 @@ class AttentionLayers(Module):
 
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
@@ -2041,6 +2066,10 @@ class AttentionLayers(Module):
 
             if num_residual_streams > 1:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+
+                if layer_type == 'a':
+                    residual_fn = partial(residual_fn, num_input_views = 3)
+
             elif gate_residual:
                 residual_fn = GRUGating
             else:
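Wired together, the feature is switched on from the AttentionLayers kwargs shown above. A minimal usage sketch against the 1.44.7 API, assuming the standard TransformerWrapper / Decoder entry points (the sizes are illustrative; per the assert above, qkv_receive_diff_residuals requires more than one residual stream):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,        # illustrative vocab size
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        num_residual_streams = 4,            # turns on hyper connections
        qkv_receive_diff_residuals = True    # q, k, v each receive their own residual mix
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)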
x_transformers-1.44.6.dist-info/METADATA → x_transformers-1.44.7.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: x-transformers
-Version: 1.44.6
+Version: 1.44.7
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.44.6.dist-info/RECORD → x_transformers-1.44.7.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=PYrwLPEUaiWuPmDNV7nQQZChfMlPJbF9NULHl9Te3LQ,103494
+x_transformers/x_transformers.py,sha256=4jgGkk-OkPdZYPouPs768KHw5eFmskC3m61h6LPfkKY,104695
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.44.6.dist-info/METADATA,sha256=3_gOvzIcumtCNqhjGmlcAPMZ2FO6q4sVlhAV-_sybBA,924
-x_transformers-1.44.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
-x_transformers-1.44.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.44.6.dist-info/RECORD,,
+x_transformers-1.44.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.7.dist-info/METADATA,sha256=D9xg_-8z4eeOXCwP1w86SFthU9RanucoxlpBLkDCUJ0,924
+x_transformers-1.44.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.7.dist-info/RECORD,,