x-transformers 1.42.28__py3-none-any.whl → 1.43.1__py3-none-any.whl
- x_transformers/x_transformers.py +110 -8
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/METADATA +1 -1
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/RECORD +6 -6
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py CHANGED
@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
 # residual and residual gates

 class Residual(Module):
-    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
         self.scale_residual_constant = scale_residual_constant

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

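The change above splits the residual pathway into a two-step interface: `prepare` transforms the incoming stream before the attention/feedforward branch runs and returns any extra keyword arguments, and `forward` folds the branch output back in. A minimal sketch of how a caller drives this interface (the branch is stubbed out and the dimensions are illustrative, not library code):

import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def prepare(self, residual):
        # plain residuals need no preprocessing and pass no extra kwargs to forward
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        if self.residual_scale is not None:
            residual = residual * self.residual_scale
        return x + residual * self.scale_residual_constant

residual_fn = Residual(dim = 512)
x = torch.randn(2, 16, 512)

branch_input, inner_residual, residual_kwargs = residual_fn.prepare(x)
out = branch_input * 2                                   # stand-in for an attention / feedforward block
x = residual_fn(out, inner_residual, **residual_kwargs)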
@@ -844,7 +847,10 @@ class GRUGating(Module):
         self.gru = nn.GRUCell(dim, dim)
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

@@ -855,6 +861,69 @@ class GRUGating(Module):

         return gated_output.reshape_as(x)

+# hyper connections
+
+class HyperConnection(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layer_index,
+        num_residual_streams,
+        tanh = True,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = nn.LayerNorm(dim, bias = False)
+
+        self.num_residual_streams = num_residual_streams
+        self.layer_index = layer_index
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, x, residuals, *, beta):
+        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # token shifting

 def shift(t, amount, mask = None):
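For orientation, a small standalone walkthrough of the width- and depth-connection shapes used by `prepare` and `forward` above (dummy tensors and illustrative dimensions, not library code): with s residual streams, alpha mixes the s streams into s + 1 outputs, index 0 becoming the branch input and indices 1..s the updated streams, while beta weights how the branch output is added back to each stream.

import torch
from torch import einsum
from einops import rearrange

b, n, d, s = 2, 16, 512, 4              # batch, sequence length, model dim, residual streams

residuals = torch.randn(b, n, s, d)
alpha = torch.randn(b, n, s, s + 1)     # width-connection weights (static + dynamic in the module)
beta = torch.randn(b, n, s)             # depth-connection weights

# width connection: mix s streams into s + 1 outputs
mix_h = einsum('b n s t, b n s d -> b n t d', alpha, residuals)   # (2, 16, 5, 512)
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]      # (2, 16, 512), (2, 16, 4, 512)

# depth connection: add the branch output back to every stream, weighted by beta
branch_output = branch_input            # stand-in for the attention / feedforward branch
residuals = einsum('b n d, b n s -> b n s d', branch_output, beta) + residuals

# streams are carried between layers folded into the batch dimension
flat = rearrange(residuals, 'b n s d -> (b s) n d')                # (8, 16, 512)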
@@ -1582,10 +1651,12 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
+        residual_fn_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
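The new `num_residual_streams` and `residual_fn_kwargs` arguments are accepted by the `AttentionLayers` constructor, so they can be passed through the usual wrappers (as the next hunk shows, `num_residual_streams > 1` cannot be combined with `gate_residual`). A usage sketch, assuming x-transformers 1.43.x with arbitrary sizes:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        num_residual_streams = 4   # > 1 switches every layer's residual to HyperConnection
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                  # (1, 1024, 20000)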
@@ -1607,6 +1678,17 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+        assert num_residual_streams > 0
+
+        self.num_residual_streams = num_residual_streams
+        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+        assert not (num_residual_streams > 1 and gate_residual)
+
+        # positions related
+
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
@@ -1872,9 +1954,14 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)

-            residual_fn = GRUGating if gate_residual else Residual
+            if num_residual_streams > 1:
+                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+            elif gate_residual:
+                residual_fn = GRUGating
+            else:
+                residual_fn = Residual

-            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

             # handle unet skip connection

@@ -2024,6 +2111,16 @@ class AttentionLayers(Module):

         iter_attn_cache = iter(attn_cache)

+        # setup multistreams if needed
+
+        streams = self.num_residual_streams
+        is_multistream = streams > 1
+
+        if is_multistream:
+            x = repeat(x, 'b n d -> b n s d', s = streams)
+            x = x + self.stream_emb
+            x = rearrange(x, 'b n s d -> (b s) n d')
+
         # outer residual - for resiDual paper

         outer_residual = x * self.resi_dual_scale
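A small sketch of what the multistream setup above does to the activations (illustrative dimensions, not library code): the input is copied into s streams, offset by a learned per-stream embedding so the streams can diverge, and the streams are folded into the batch dimension so the rest of the layer stack runs unchanged.

import torch
from einops import rearrange, repeat

b, n, d, s = 2, 16, 512, 4
x = torch.randn(b, n, d)
stream_emb = torch.zeros(s, d)               # an nn.Parameter in the real module

x = repeat(x, 'b n d -> b n s d', s = s)     # (2, 16, 4, 512): one copy per residual stream
x = x + stream_emb                           # broadcast per-stream offsets
x = rearrange(x, 'b n s d -> (b s) n d')     # (8, 16, 512): streams folded into batch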
@@ -2090,7 +2187,7 @@ class AttentionLayers(Module):
             if self.training and self.cross_attn_tokens_dropout > 0.:
                 context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)

             if return_hiddens:
                 layer_hiddens.append(x)
@@ -2148,7 +2245,7 @@ class AttentionLayers(Module):
             if exists(post_branch_norm):
                 out = post_branch_norm(out)

-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)

             if layer_type in ('a', 'c') and return_hiddens:
                 inter.layer_type = layer_type
@@ -2178,6 +2275,11 @@ class AttentionLayers(Module):
         else:
             x = final_norm(x)

+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
         if not return_hiddens:
             return x

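And the mirror image at the end of the stack: after the final norm, the streams are summed back into a single representation (a sum reduction for now, as the comment notes). A sketch with the same illustrative dimensions as above:

import torch
from einops import reduce

b, n, d, s = 2, 16, 512, 4
x = torch.randn(b * s, n, d)                       # multistream layout: streams folded into batch
x = reduce(x, '(b s) n d -> b n d', 'sum', s = s)  # (2, 16, 512)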
{x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/RECORD CHANGED

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=JG38kcXdhRBKT5_FHMhV5dQabSGrAHsuIQkHjPalDiI,100384
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.43.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.43.1.dist-info/METADATA,sha256=V57c6Bps0GjG0GLEBpxkHdbvxIWzXss2Xu5_KQJJXPc,738
+x_transformers-1.43.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.43.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.43.1.dist-info/RECORD,,
{x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/LICENSE: file without changes
{x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/WHEEL: file without changes
{x_transformers-1.42.28.dist-info → x_transformers-1.43.1.dist-info}/top_level.txt: file without changes