x-transformers 1.44.5__py3-none-any.whl → 1.44.7__py3-none-any.whl
- x_transformers/x_transformers.py +48 -11
- {x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/METADATA +1 -1
- {x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/RECORD +6 -6
- {x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -882,6 +882,7 @@ class HyperConnection(Module):
         *,
         layer_index,
         num_residual_streams,
+        num_input_views = 1,
         tanh = True,
         **kwargs
     ):
@@ -900,13 +901,16 @@ class HyperConnection(Module):
 
         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
 
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+        init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
+        init_alpha0[layer_index % num_residual_streams, :] = 1.
 
         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
 
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+        self.num_input_views = num_input_views
+
         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
 
@@ -928,7 +932,13 @@ class HyperConnection(Module):
 
         mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
 
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        views = self.num_input_views
+
+        if views == 1:
+            branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+        else:
+            branch_input, residuals = mix_h[..., :views, :], mix_h[..., views:, :]
+            branch_input = rearrange(branch_input, '... v d -> v ... d')
 
         return branch_input, residuals, dict(beta = beta)
 
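To make the shape bookkeeping of `num_input_views` concrete, here is a minimal, self-contained sketch (not the library's code; the tensor sizes and the random `alpha` are made up) of what `HyperConnection.prepare` now does: the mixing matrix maps the residual streams onto `num_input_views` branch-input views plus the carried streams, and the views are moved to a leading axis so the wrapped block can unpack them.

```python
# Shape sketch only: a (streams x (views + streams)) mixing matrix, as in the hunks above.
import torch
from einops import rearrange

batch, seq, dim = 2, 16, 8
streams, views  = 4, 3          # num_residual_streams, num_input_views

residuals = torch.randn(batch, seq, streams, dim)                     # hypothetical residual streams
alpha     = torch.softmax(torch.randn(streams, views + streams), -1)  # stand-in for static + dynamic alpha

# same contraction as `einsum('... s t, ... s d -> ... t d', alpha, residuals)`
mix_h = torch.einsum('s t, b n s d -> b n t d', alpha, residuals)

branch_input  = mix_h[..., :views, :]    # first `views` targets feed the block (q / k / v)
new_residuals = mix_h[..., views:, :]    # remaining targets carry the residual streams forward

branch_input = rearrange(branch_input, '... v d -> v ... d')

print(branch_input.shape)    # torch.Size([3, 2, 16, 8]) -> unpacked as q_input, k_input, v_input
print(new_residuals.shape)   # torch.Size([2, 16, 4, 8])
```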
@@ -1200,6 +1210,7 @@ class Attention(Module):
         learned_value_residual_mix = False,
         laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
+        qkv_receive_diff_residuals = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1239,6 +1250,10 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # whether qkv receives different residual stream combinations from hyper connections
+
+        self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
+
         # enhancing gradients to attention through exponentiated values
 
         self.laser = laser
@@ -1423,14 +1438,21 @@ class Attention(Module):
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
 
-        kv_input = default(context, x)
+        assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'
 
-        q_input = x
-        k_input = kv_input
-        v_input = kv_input
-        r_input = x
+        if qkv_receive_diff_residuals:
+            assert not exists(self.to_r)
+
+            q_input, k_input, v_input = x
+        else:
+            kv_input = default(context, x)
+
+            q_input = x
+            k_input = kv_input
+            v_input = kv_input
+            r_input = x
 
         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
@@ -1735,7 +1757,9 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         num_residual_streams = 1,
+        qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        learned_reinject_input_gate = False,
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
@@ -1770,6 +1794,8 @@ class AttentionLayers(Module):
 
         assert not (num_residual_streams > 1 and gate_residual)
 
+        assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+
         # positions related
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
@@ -1993,6 +2019,7 @@ class AttentionLayers(Module):
 
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
+        self.learned_reinject_input_gate = nn.Linear(dim, 1, bias = False) if learned_reinject_input_gate else None
 
         # add the value from the first self attention block to all latter projected self attention values as a residual
 
@@ -2018,7 +2045,7 @@ class AttentionLayers(Module):
 
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
                 is_first_self_attn = False
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
@@ -2039,6 +2066,10 @@ class AttentionLayers(Module):
 
             if num_residual_streams > 1:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+
+                if layer_type == 'a':
+                    residual_fn = partial(residual_fn, num_input_views = 3)
+
             elif gate_residual:
                 residual_fn = GRUGating
             else:
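The double `partial` above is what routes three input views only to the attention layers. A simplified, runnable sketch of that wiring, with a stand-in `hyper_connection` function in place of the real `HyperConnection` class:

```python
from functools import partial

def hyper_connection(dim, *, layer_index, num_residual_streams, num_input_views = 1):
    # stand-in that just reports how it was configured
    return dict(dim = dim, layer_index = layer_index,
                num_residual_streams = num_residual_streams,
                num_input_views = num_input_views)

num_residual_streams = 4

for layer_index, layer_type in enumerate(('a', 'f', 'a', 'f')):   # attention / feedforward layers
    residual_fn = partial(hyper_connection, num_residual_streams = num_residual_streams)

    if layer_type == 'a':
        # attention gets three input views (for q, k, v); feedforward keeps the single-view default
        residual_fn = partial(residual_fn, num_input_views = 3)

    print(layer_type, residual_fn(dim = 8, layer_index = layer_index))
```

At the user level, the hunks above expose this through the `num_residual_streams` and `qkv_receive_diff_residuals` keyword arguments of `AttentionLayers`.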
@@ -2224,7 +2255,9 @@ class AttentionLayers(Module):
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
+
         inp_inject = None
+
         if self.reinject_input:
             assert not exists(in_attn_cond)
             inp_inject = self.reinject_input_proj(x)
@@ -2233,6 +2266,10 @@ class AttentionLayers(Module):
             # handle in-attention conditioning, which serves the same purpose of having the network learn the residual
             inp_inject = in_attn_cond if in_attn_cond.ndim == 3 else rearrange(in_attn_cond, 'b d -> b 1 d')
 
+        if exists(inp_inject) and exists(self.learned_reinject_input_gate):
+            inp_inject_gate = self.learned_reinject_input_gate(x).sigmoid()
+            inp_inject = inp_inject * inp_inject_gate
+
         # store all hiddens for skips
 
         skip_hiddens = []
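The last two hunks add a learned sigmoid gate on the re-injected input. A minimal standalone sketch of the same computation (made-up sizes; not the library code, but it mirrors the gating added above):

```python
import torch
from torch import nn

dim = 8
x = torch.randn(2, 16, dim)       # stand-in for the embedded input sequence

reinject_input_proj         = nn.Linear(dim, dim, bias = False)
learned_reinject_input_gate = nn.Linear(dim, 1, bias = False)

inp_inject = reinject_input_proj(x)

# per-token gate in (0, 1), broadcast over the feature dimension
inp_inject_gate = learned_reinject_input_gate(x).sigmoid()
inp_inject      = inp_inject * inp_inject_gate

print(inp_inject.shape)           # torch.Size([2, 16, 8])
```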
{x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256
+x_transformers/x_transformers.py,sha256=4jgGkk-OkPdZYPouPs768KHw5eFmskC3m61h6LPfkKY,104695
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
-x_transformers-1.44.
+x_transformers-1.44.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.7.dist-info/METADATA,sha256=D9xg_-8z4eeOXCwP1w86SFthU9RanucoxlpBLkDCUJ0,924
+x_transformers-1.44.7.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+x_transformers-1.44.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.7.dist-info/RECORD,,
{x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/LICENSE
File without changes
{x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/WHEEL
File without changes
{x_transformers-1.44.5.dist-info → x_transformers-1.44.7.dist-info}/top_level.txt
File without changes