x-transformers 1.42.28__py3-none-any.whl → 1.43.0__py3-none-any.whl
- x_transformers/x_transformers.py +105 -7
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/METADATA +1 -1
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/RECORD +6 -6
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
 # residual and residual gates
 
 class Residual(Module):
-    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
         self.scale_residual_constant = scale_residual_constant
 
-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale
 
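This release changes the residual interface from a single forward call to a two-step contract: prepare splits the incoming tensor into the branch input, the residual to carry, and any extra keyword arguments, and forward merges the branch output back in. A minimal sketch of that contract, assuming the plain Residual above keeps its previous scale-then-add behaviour in forward, and using a Linear layer as a hypothetical stand-in for the attention/feedforward branch:

import torch
from torch import nn

class Residual(nn.Module):
    # as in the diff above; the forward body is assumed to continue with scale-then-add
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
        self.scale_residual_constant = scale_residual_constant

    def prepare(self, residual):
        # plain residuals pass through untouched and need no extra kwargs
        return residual, residual, dict()

    def forward(self, x, residual, **kwargs):
        if self.residual_scale is not None:
            residual = residual * self.residual_scale
        if self.scale_residual_constant != 1.:
            residual = residual * self.scale_residual_constant
        return x + residual

residual_fn = Residual(dim = 512)
branch = nn.Linear(512, 512)  # hypothetical stand-in for an attention or feedforward block

x = torch.randn(2, 16, 512)
x, inner_residual, residual_kwargs = residual_fn.prepare(x)  # split before the branch
out = branch(x)
x = residual_fn(out, inner_residual, **residual_kwargs)      # merge after the branch

GRUGating gains the same prepare method below, and the new HyperConnection class further down uses it to hand back the per-token mixing weights it needs at merge time.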
@@ -844,7 +847,10 @@ class GRUGating(Module):
         self.gru = nn.GRUCell(dim, dim)
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
 
-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale
 
@@ -855,6 +861,66 @@ class GRUGating(Module):
 
         return gated_output.reshape_as(x)
 
+# hyper connections
+
+class HyperConnection(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layer_index,
+        num_residual_streams,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, bias = False)
+
+        self.num_residual_streams = num_residual_streams
+        self.layer_index = layer_index
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = (normed @ self.dynamic_alpha_fn).tanh()
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = (normed @ self.dynamic_beta_fn).tanh()
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, x, residuals, *, beta):
+        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # token shifting
 
 def shift(t, amount, mask = None):
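HyperConnection implements the dynamic hyper-connections of arXiv:2409.19606: the model keeps s copies of the residual, prepare mixes them with per-token width-connection weights (alpha) into s + 1 outputs — the first feeds the attention/feedforward branch, the rest become the updated streams — and forward writes the branch output back into every stream weighted by depth-connection weights (beta). A shape-only sketch of that mixing, with random stand-ins for alpha and beta (in the class above they come from the static parameters plus the tanh-gated dynamic projections):

import torch
from torch import einsum
from einops import rearrange

# b = batch, n = sequence length, s = residual streams, d = model dim
b, n, s, d = 2, 16, 4, 512

# residuals arrive folded into the batch as (b s) n d and are unpacked per stream
residuals = torch.randn(b * s, n, d)
residuals = rearrange(residuals, '(b s) n d -> b n s d', s = s)

alpha = torch.randn(b, n, s, s + 1)  # stand-in for static_alpha + dynamic alpha
beta = torch.randn(b, n, s)          # stand-in for static_beta + dynamic beta

# width connection: mix the s streams into s + 1 outputs
mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
print(branch_input.shape, residuals.shape)  # (2, 16, 512) (2, 16, 4, 512)

# depth connection: add the branch output back into every stream, weighted by beta
branch_out = torch.randn(b, n, d)  # what the attention / feedforward block would return
residuals = einsum('b n d, b n s -> b n s d', branch_out, beta) + residuals

# refold for the next layer
residuals = rearrange(residuals, 'b n s d -> (b s) n d')
print(residuals.shape)  # (8, 16, 512)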
@@ -1582,6 +1648,7 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
         add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
@@ -1607,6 +1674,17 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
+        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+        assert num_residual_streams > 0
+
+        self.num_residual_streams = num_residual_streams
+        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+        assert not (num_residual_streams > 1 and gate_residual)
+
+        # positions related
+
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
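Since the new num_residual_streams argument is just another AttentionLayers keyword, enabling hyper connections from the public API should only require passing it through Encoder or Decoder. A usage sketch under that assumption (everything else is the standard TransformerWrapper API; the values are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        num_residual_streams = 4  # > 1 routes residuals through HyperConnection
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)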
@@ -1872,9 +1950,14 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)
 
-            residual_fn = GRUGating if gate_residual else Residual
+            if num_residual_streams > 1:
+                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+            elif gate_residual:
+                residual_fn = GRUGating
+            else:
+                residual_fn = Residual
 
-            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
             # handle unet skip connection
 
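All three residual flavours are now built through the same call site, which is why partial pre-binds the stream count here and why Residual and GRUGating grew a **kwargs catch-all in their constructors: HyperConnection needs layer_index and num_residual_streams, while the other two simply ignore them. A small sketch of that pattern with hypothetical stand-in classes:

from functools import partial

class PlainResidual:
    # stand-in for Residual / GRUGating: extra kwargs are accepted and ignored
    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
        self.dim = dim

class HyperResidual:
    # stand-in for HyperConnection: needs the stream count and its layer index
    def __init__(self, dim, *, layer_index, num_residual_streams, **kwargs):
        self.layer_index = layer_index
        self.num_residual_streams = num_residual_streams

num_residual_streams = 4

residuals = []
for ind in range(6):  # one residual module per layer
    if num_residual_streams > 1:
        residual_fn = partial(HyperResidual, num_residual_streams = num_residual_streams)
    else:
        residual_fn = PlainResidual

    # identical constructor call regardless of which class was selected
    residuals.append(residual_fn(512, layer_index = ind, scale_residual = False))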
@@ -2024,6 +2107,16 @@ class AttentionLayers(Module):
 
         iter_attn_cache = iter(attn_cache)
 
+        # setup multistreams if needed
+
+        streams = self.num_residual_streams
+        is_multistream = streams > 1
+
+        if is_multistream:
+            x = repeat(x, 'b n d -> b n s d', s = streams)
+            x = x + self.stream_emb
+            x = rearrange(x, 'b n s d -> (b s) n d')
+
         # outer residual - for resiDual paper
 
         outer_residual = x * self.resi_dual_scale
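At the start of the forward pass the single input is duplicated into streams copies, each offset by its learned stream embedding, and the streams are folded into the batch dimension so every layer still sees an ordinary (batch, seq, dim) tensor; the final hunk below sums them back out. A shape-only sketch of that fold/unfold, with a zero tensor standing in for the learned stream_emb parameter:

import torch
from einops import rearrange, reduce, repeat

b, n, d, streams = 2, 16, 512, 4

x = torch.randn(b, n, d)
stream_emb = torch.zeros(streams, d)  # stand-in for the learned nn.Parameter

# expand: one copy per stream, offset by the stream embedding, folded into the batch
x = repeat(x, 'b n d -> b n s d', s = streams)
x = x + stream_emb
x = rearrange(x, 'b n s d -> (b s) n d')
print(x.shape)  # torch.Size([8, 16, 512])

# ... attention / feedforward layers run unchanged on (b * streams, n, d) ...

# collapse: sum the streams back into a single residual after the final norm
x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
print(x.shape)  # torch.Size([2, 16, 512])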
@@ -2090,7 +2183,7 @@ class AttentionLayers(Module):
                 if self.training and self.cross_attn_tokens_dropout > 0.:
                     context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
 
-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)
 
             if return_hiddens:
                 layer_hiddens.append(x)
@@ -2148,7 +2241,7 @@ class AttentionLayers(Module):
             if exists(post_branch_norm):
                 out = post_branch_norm(out)
 
-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)
 
             if layer_type in ('a', 'c') and return_hiddens:
                 inter.layer_type = layer_type
@@ -2178,6 +2271,11 @@ class AttentionLayers(Module):
         else:
             x = final_norm(x)
 
+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
         if not return_hiddens:
             return x
 
{x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=wAY0lqZvFlXk-fmpr4Ot6yZ6ivzEjetFXTin7z7eA88,100075
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.43.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.43.0.dist-info/METADATA,sha256=C6eRstMfzmbxQUxNeKnt1Mf-e9pJ45GKNJ8hsc_3uwo,738
+x_transformers-1.43.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.43.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.43.0.dist-info/RECORD,,
{x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/LICENSE
File without changes

{x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/WHEEL
File without changes

{x_transformers-1.42.28.dist-info → x_transformers-1.43.0.dist-info}/top_level.txt
File without changes