x-transformers 1.44.6__py3-none-any.whl → 1.44.7__py3-none-any.whl
- x_transformers/x_transformers.py +40 -11
- {x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/METADATA +1 -1
- {x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/RECORD +6 -6
- {x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -882,6 +882,7 @@ class HyperConnection(Module):
         *,
         layer_index,
         num_residual_streams,
+        num_input_views = 1,
         tanh = True,
         **kwargs
     ):
@@ -900,13 +901,16 @@ class HyperConnection(Module):
 
         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
 
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+        init_alpha0[layer_index % num_residual_streams, :] = 1.
 
         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
 
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        self.num_input_views = num_input_views
+
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
@@ -928,7 +932,13 @@ class HyperConnection(Module):
 
         mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
 
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        views = self.num_input_views
+
+        if views == 1:
+            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        else:
+            branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
+            branch_input = rearrange(branch_input, '... v d -> v ... d')
 
         return branch_input, residuals, dict(beta = beta)
 
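Taken together, the two HyperConnection hunks widen the static and dynamic mixing matrices by `num_input_views` columns and, when more than one view is requested, stack the mixed branch inputs along a leading view axis. Below is a standalone shape sketch of what the patched init produces; all concrete numbers (4 streams, 3 views, dim 8, etc.) are illustrative assumptions, not values from the library, and the dynamic mixing term is omitted.

```python
import torch

# illustrative values: 4 residual streams, 3 input views (one each for q, k, v)
num_residual_streams, num_input_views, layer_index = 4, 3, 1

# mirrors the patched init above
init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
init_alpha0[layer_index % num_residual_streams, :] = 1.
static_alpha = torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1)

print(static_alpha.shape)  # torch.Size([4, 7]) -> num_input_views + num_residual_streams mix targets

# mixing the residual streams with this alpha yields 7 "target" vectors per token:
# the first 3 become the stacked branch input, the remaining 4 stay as residual streams
batch, seq, dim = 2, 16, 8
residuals = torch.randn(batch, seq, num_residual_streams, dim)
alpha = static_alpha.expand(batch, seq, -1, -1)            # static part only in this sketch
mix_h = torch.einsum('bnst,bnsd->bntd', alpha, residuals)

branch_input, residuals = mix_h[..., :num_input_views, :], mix_h[..., num_input_views:, :]
print(branch_input.shape, residuals.shape)  # torch.Size([2, 16, 3, 8]) torch.Size([2, 16, 4, 8])
```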
@@ -1200,6 +1210,7 @@ class Attention(Module):
         learned_value_residual_mix = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
+        qkv_receive_diff_residuals = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1239,6 +1250,10 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # whether qkv receives different residual stream combinations from hyper connections
+
+        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
+
         # enhancing gradients to attention through exponentiated values
 
         self.laser = laser
@@ -1423,14 +1438,21 @@ class Attention(Module):
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
 
-        kv_input = default(context, x)
+        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
-        q_input = x
-        k_input = kv_input
-        v_input = kv_input
-        r_input = x
+        if qkv_receive_diff_residuals:
+            assert not exists(self.to_r)
+
+            q_input, k_input, v_input = x
+        else:
+            kv_input = default(context, x)
+
+            q_input = x
+            k_input = kv_input
+            v_input = kv_input
+            r_input = x
 
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
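When the new flag is set, `Attention.forward` expects `x` to already be the stack of views produced by the hyper connection's `'... v d -> v ... d'` rearrange and unpacks it into separate query, key and value inputs; cross attention with an external `context` is rejected by the assert. A minimal sketch of that input contract, with assumed shapes:

```python
import torch

# assumed shapes, for illustration only
views, batch, seq, dim = 3, 2, 16, 512

# what a hyper connection hands to a self-attention block when
# qkv_receive_diff_residuals = True: three mixed views stacked on a leading axis
x = torch.randn(views, batch, seq, dim)

# mirrors the unpacking in the patched forward above
q_input, k_input, v_input = x
print(q_input.shape, k_input.shape, v_input.shape)  # each torch.Size([2, 16, 512])
```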
@@ -1735,6 +1757,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         num_residual_streams = 1,
+        qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
@@ -1771,6 +1794,8 @@ class AttentionLayers(Module):
 
         assert not (num_residual_streams > 1 and gate_residual)
 
+        assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+
         # positions related
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
@@ -2020,7 +2045,7 @@ class AttentionLayers(Module):
 
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
@@ -2041,6 +2066,10 @@ class AttentionLayers(Module):
 
             if num_residual_streams > 1:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+
+                if layer_type == 'a':
+                    residual_fn = partial(residual_fn, num_input_views = 3)
+
             elif gate_residual:
                 residual_fn = GRUGating
             else:
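End to end, `qkv_receive_diff_residuals` is an `AttentionLayers` option that only makes sense with hyper connections enabled (`num_residual_streams > 1`, enforced by the new assert), and self-attention layers then request three input views from each `HyperConnection`. A hedged usage sketch follows, assuming the usual `TransformerWrapper` / `Decoder` entry points of the package; only `num_residual_streams` and `qkv_receive_diff_residuals` come from the diff above, the remaining hyperparameters are illustrative.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# a sketch, not an official example: Decoder forwards these kwargs to AttentionLayers
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        num_residual_streams = 4,            # enables HyperConnection residual streams
        qkv_receive_diff_residuals = True    # q, k and v each get their own stream mix
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)
```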
{x_transformers-1.44.6.dist-info → x_transformers-1.44.7.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=4jgGkk-OkPdZYPouPs768KHw5eFmskC3m61h6LPfkKY,104695
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
+x_transformers-1.44.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.7.dist-info/METADATA,sha256=D9xg_-8z4eeOXCwP1w86SFthU9RanucoxlpBLkDCUJ0,924
+x_transformers-1.44.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.7.dist-info/RECORD,,
LICENSE, WHEEL and top_level.txt: file contents unchanged.