x-transformers 1.42.28__tar.gz → 1.43.1__tar.gz
- {x_transformers-1.42.28/x_transformers.egg-info → x_transformers-1.43.1}/PKG-INFO +1 -1
- {x_transformers-1.42.28 → x_transformers-1.43.1}/README.md +12 -1
- {x_transformers-1.42.28 → x_transformers-1.43.1}/setup.py +1 -1
- {x_transformers-1.42.28 → x_transformers-1.43.1}/tests/test_x_transformers.py +21 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/x_transformers.py +110 -8
- {x_transformers-1.42.28 → x_transformers-1.43.1/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.42.28 → x_transformers-1.43.1}/LICENSE +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/setup.cfg +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/__init__.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/attend.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/continuous.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/dpo.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/xval.py +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.42.28 → x_transformers-1.43.1}/README.md

@@ -2240,7 +2240,7 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

-```
+```bibtex
 @article{Yang2017BreakingTS,
     title   = {Breaking the Softmax Bottleneck: A High-Rank RNN Language Model},
     author  = {Zhilin Yang and Zihang Dai and Ruslan Salakhutdinov and William W. Cohen},

@@ -2363,4 +2363,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@article{Zhu2024HyperConnections,
+    title   = {Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2409.19606},
+    url     = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-1.42.28 → x_transformers-1.43.1}/tests/test_x_transformers.py

@@ -590,3 +590,24 @@ def test_cross_attn_rotary(
         context_pos = context_pos,
         context_mask = context_mask
     )
+
+@pytest.mark.parametrize('tanh', (True, False))
+def test_hyper_connections(tanh):
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            num_residual_streams = 8, # 8 dynamic hyper connection residual streams
+            residual_fn_kwargs = dict(
+                tanh = tanh
+            )
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    model(x)
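
The new test above is also the quickest usage reference for the feature. For completeness, a hedged training sketch built around the same configuration follows; `AutoregressiveWrapper` is the existing x-transformers loss wrapper, while the optimizer, learning rate and random data are illustrative assumptions, not part of this release.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# minimal training sketch assuming the configuration from the new test above;
# data, optimizer and hyperparameters are illustrative, not taken from the package
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        num_residual_streams = 8  # turn on hyper connections
    )
)

model = AutoregressiveWrapper(model)
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 20000, (2, 1024))  # stand-in training batch

loss = model(seq)   # wrapper shifts the sequence internally and returns cross entropy
loss.backward()
optim.step()
```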

{x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/x_transformers.py

@@ -824,12 +824,15 @@ class SimpleRMSNorm(Module):
 # residual and residual gates

 class Residual(Module):
-    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
+    def __init__(self, dim, scale_residual = False, scale_residual_constant = 1., **kwargs):
         super().__init__()
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
         self.scale_residual_constant = scale_residual_constant

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

@@ -844,7 +847,10 @@ class GRUGating(Module):
         self.gru = nn.GRUCell(dim, dim)
         self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

-    def forward(self, x, residual):
+    def prepare(self, residual):
+        return residual, residual, dict()
+
+    def forward(self, x, residual, **kwargs):
         if exists(self.residual_scale):
             residual = residual * self.residual_scale

@@ -855,6 +861,69 @@ class GRUGating(Module):

         return gated_output.reshape_as(x)

+# hyper connections
+
+class HyperConnection(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        layer_index,
+        num_residual_streams,
+        tanh = True,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = nn.LayerNorm(dim, bias = False)
+
+        self.num_residual_streams = num_residual_streams
+        self.layer_index = layer_index
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum('... s t, ... s d -> ... t d', alpha, residuals)
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, x, residuals, *, beta):
+        residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # token shifting

 def shift(t, amount, mask = None):
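
In the class above, `prepare` implements the width connection (unfold the streams from the batch dimension, derive data-dependent `alpha`/`beta` on top of their static initializations, mix the streams into a single branch input) and `forward` implements the depth connection (write the branch output back into every stream weighted by `beta`, then refold). A hedged shape check, assuming the class is importable from the internal `x_transformers.x_transformers` module and using random values purely to exercise the plumbing:

```python
import torch
from x_transformers.x_transformers import HyperConnection

# shape walkthrough of the class above; only tensor shapes are being checked,
# and streams are kept folded into the batch dimension as in AttentionLayers
batch, seq, dim, streams = 2, 16, 64, 4

hc = HyperConnection(dim, layer_index = 0, num_residual_streams = streams)

residuals = torch.randn(batch * streams, seq, dim)   # '(b s) n d' layout

# width connection: mix streams into one branch input plus the remaining streams
branch_input, mixed, kwargs = hc.prepare(residuals)
assert branch_input.shape == (batch, seq, dim)
assert mixed.shape == (batch, seq, streams, dim)
assert kwargs['beta'].shape == (batch, seq, streams)

# depth connection: scatter the branch output back across all streams
branch_output = torch.randn(batch, seq, dim)         # stand-in for attn / ff output
out = hc(branch_output, mixed, **kwargs)
assert out.shape == (batch * streams, seq, dim)
```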

@@ -1582,10 +1651,12 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        num_residual_streams = 1,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 - further corroboration by https://arxiv.org/abs/2412.15113 (faster emergence of ICL) - looks like this setting may becoming a necessity for every transformer soon
         learned_value_residual_mix = True, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
+        residual_fn_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()

@@ -1607,6 +1678,17 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

+        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+
+        assert num_residual_streams > 0
+
+        self.num_residual_streams = num_residual_streams
+        self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
+
+        assert not (num_residual_streams > 1 and gate_residual)
+
+        # positions related
+
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))

         rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
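
A quick sketch of the new guard rails above: `num_residual_streams` defaults to 1 (plain residuals), must be positive, and more than one stream cannot be combined with the existing GRU residual gating, since hyper connections replace the residual path entirely. The `dim`/`depth`/`heads` values below are illustrative only.

```python
from x_transformers import Decoder

# a hyper-connection decoder constructs fine on its own...
Decoder(dim = 64, depth = 2, heads = 4, num_residual_streams = 4)

# ...but the new assertion rejects combining it with GRU residual gating
try:
    Decoder(dim = 64, depth = 2, heads = 4, num_residual_streams = 4, gate_residual = True)
except AssertionError:
    print('num_residual_streams > 1 and gate_residual are mutually exclusive')
```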

@@ -1872,9 +1954,14 @@ class AttentionLayers(Module):
            if exists(post_branch_fn):
                layer = post_branch_fn(layer)

-            residual_fn = GRUGating if gate_residual else Residual
+            if num_residual_streams > 1:
+                residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
+            elif gate_residual:
+                residual_fn = GRUGating
+            else:
+                residual_fn = Residual

-            residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
+            residual = residual_fn(dim, layer_index = ind, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant, **residual_fn_kwargs)

            # handle unet skip connection
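
In the selection above, `partial` pre-binds `num_residual_streams` so all three residual variants can be constructed with one call signature, and `layer_index = ind` lets each `HyperConnection` point its static `alpha` at stream `ind % num_residual_streams`. A small, hedged sketch of that factory in isolation (constants are illustrative, and the extra `scale_residual` arguments are omitted for brevity):

```python
from functools import partial

from x_transformers.x_transformers import HyperConnection, GRUGating, Residual

# sketch of the per-layer residual factory above; dim / stream counts are illustrative
dim, num_residual_streams, gate_residual = 64, 4, False

if num_residual_streams > 1:
    residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
elif gate_residual:
    residual_fn = GRUGating
else:
    residual_fn = Residual

residuals = [residual_fn(dim, layer_index = ind) for ind in range(6)]

# each depth initially reads from a different residual stream, round-robin
print([int(r.static_alpha[:, 0].argmax()) for r in residuals])  # [0, 1, 2, 3, 0, 1]
```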

@@ -2024,6 +2111,16 @@ class AttentionLayers(Module):

        iter_attn_cache = iter(attn_cache)

+        # setup multistreams if needed
+
+        streams = self.num_residual_streams
+        is_multistream = streams > 1
+
+        if is_multistream:
+            x = repeat(x, 'b n d -> b n s d', s = streams)
+            x = x + self.stream_emb
+            x = rearrange(x, 'b n s d -> (b s) n d')
+
        # outer residual - for resiDual paper

        outer_residual = x * self.resi_dual_scale
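
Folding the streams into the batch dimension is what keeps the rest of the layer stack untouched: attention and feedforward simply see a batch that is `num_residual_streams` times larger. A small einops sketch of the expansion above together with the symmetric sum reduction applied at the end of forward (tensor sizes are arbitrary):

```python
import torch
from einops import repeat, rearrange, reduce

# shape sketch of the multistream setup above and its inverse at the end of forward
b, n, d, s = 2, 16, 64, 4

x = torch.randn(b, n, d)
stream_emb = torch.zeros(s, d)                      # stands in for self.stream_emb

x = repeat(x, 'b n d -> b n s d', s = s)            # copy hidden state into every stream
x = x + stream_emb                                  # learned offset differentiates streams
x = rearrange(x, 'b n s d -> (b s) n d')            # fold streams into the batch
assert x.shape == (b * s, n, d)

# ... attention / feedforward layers operate on the folded tensor unchanged ...

x = reduce(x, '(b s) n d -> b n d', 'sum', s = s)   # collapse streams back by summation
assert x.shape == (b, n, d)
```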

@@ -2090,7 +2187,7 @@ class AttentionLayers(Module):
                if self.training and self.cross_attn_tokens_dropout > 0.:
                    context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

-            inner_residual = x
+            x, inner_residual, residual_kwargs = residual_fn.prepare(x)

            if return_hiddens:
                layer_hiddens.append(x)

@@ -2148,7 +2245,7 @@ class AttentionLayers(Module):
            if exists(post_branch_norm):
                out = post_branch_norm(out)

-            x = residual_fn(out, inner_residual)
+            x = residual_fn(out, inner_residual, **residual_kwargs)

            if layer_type in ('a', 'c') and return_hiddens:
                inter.layer_type = layer_type
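
Taken together, the two one-line changes above form the whole calling protocol: `prepare` yields the branch input, the residual to add back, and any extra keyword arguments (empty for `Residual` and `GRUGating`, `beta` for `HyperConnection`), and the residual module is then called with the branch output plus those kwargs. A minimal sketch of that protocol using the plain `Residual` class, assuming it is imported from the internal `x_transformers.x_transformers` module and with an identity layer standing in for the attention/feedforward branch:

```python
import torch
from torch import nn
from x_transformers.x_transformers import Residual

# sketch of the prepare/forward residual protocol used in the two hunks above;
# for Residual (and GRUGating) prepare is a no-op returning (residual, residual, {}),
# while HyperConnection uses it to mix residual streams and pass `beta` to forward
dim = 128
residual = Residual(dim)
branch = nn.Identity()   # stand-in for an attention or feedforward block

x = torch.randn(2, 16, dim)

branch_input, inner_residual, residual_kwargs = residual.prepare(x)
out = branch(branch_input)
x = residual(out, inner_residual, **residual_kwargs)   # plain additive residual here

assert x.shape == (2, 16, dim)
```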

@@ -2178,6 +2275,11 @@ class AttentionLayers(Module):
        else:
            x = final_norm(x)

+        # take care of multistreams if needed, use sum for now
+
+        if is_multistream:
+            x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)
+
        if not return_hiddens:
            return x
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/nonautoregressive_wrapper.py
RENAMED
File without changes
|
{x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers/xl_autoregressive_wrapper.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
{x_transformers-1.42.28 → x_transformers-1.43.1}/x_transformers.egg-info/dependency_links.txt
RENAMED
File without changes
|
File without changes
|
File without changes
|