x-transformers 2.0.5__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +86 -13
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/METADATA +10 -1
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/RECORD +5 -5
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange
+from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -962,6 +962,45 @@ class HyperConnection(Module):
         residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
         return rearrange(residuals, 'b n s d -> (b s) n d')
 
+# LIMe - layer integrated memory (dynamic version)
+
+class DynamicLIMe(Module):
+    def __init__(
+        self,
+        dim,
+        num_layers,
+        num_views = 1,
+        use_softmax = True
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.multiple_views = num_views > 1
+
+        self.to_weights = Sequential(
+            nn.Linear(dim, num_views * num_layers),
+            Rearrange('... (views layers) -> views ... layers', views = num_views),
+            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        hiddens
+    ):
+        if not is_tensor(hiddens):
+            hiddens = stack(hiddens)
+
+        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+        weights = self.to_weights(x)
+
+        out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+        if self.multiple_views:
+            return out
+
+        return rearrange(out, '1 ... -> ...')
+
 # token shifting
 
 def shift(t, amount, mask = None):
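As a quick annotation on the hunk above (not part of the diff): `DynamicLIMe` learns, per token, a weight over every earlier layer's hidden state and mixes them with the einsum shown. A minimal shape sketch, assuming the module is imported from `x_transformers.x_transformers` and using arbitrary batch and sequence sizes:

```python
# minimal sketch, not from the package's tests - shapes are illustrative
import torch
from x_transformers.x_transformers import DynamicLIMe

lime = DynamicLIMe(dim = 512, num_layers = 4)            # single view (default)

x = torch.randn(2, 128, 512)                             # (batch, seq, dim) - current stream
hiddens = [torch.randn(2, 128, 512) for _ in range(4)]   # one hidden state per prior layer

out = lime(x, hiddens)                                   # the list is stacked to (layers, batch, seq, dim) internally
print(out.shape)                                         # torch.Size([2, 128, 512]) - the single view is squeezed out
```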
@@ -1306,7 +1345,7 @@ class Attention(Module):
 
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        # whether qkv receives different residual stream combinations from hyper connections
+        # whether qkv receives different residual stream combinations from hyper connections or lime
 
         self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
 
@@ -1869,6 +1908,8 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        integrate_layers = False,
+        layer_integrate_use_softmax = True,
         num_residual_streams = 1,
         qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
@@ -1895,16 +1936,30 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
-        #
+        # routing related
+        # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+        # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+        qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+        # hyper connections
 
         assert num_residual_streams > 0
+        has_hyper_connections = num_residual_streams > 1
 
         self.num_residual_streams = num_residual_streams
         self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
 
-        assert not (
+        assert not (has_hyper_connections and gate_residual)
+
+        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
+
+        # LIMe
 
-
+        hiddens_counter = 0
+        self.layer_integrators = ModuleList([])
+
+        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
@@ -2145,16 +2200,22 @@ class AttentionLayers(Module):
 
             # attention, cross attention, feedforward
 
+            layer_qkv_receives_diff_view = layer_type == 'a' and qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-
+
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = layer_qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                 is_first_self_attn = False
+
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
+
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
+
             else:
                 raise Exception(f'invalid layer type {layer_type}')
 
@@ -2166,10 +2227,18 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)
 
-
+            layer_integrate = None
+
+            if integrate_layers:
+                num_layer_hiddens = ind + 1
+                layer_integrate_num_view = 3 if layer_qkv_receives_diff_view else 1
+
+                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+            if has_hyper_connections:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
 
-                if layer_type == 'a' and
+                if layer_type == 'a' and hyper_conn_produce_diff_views:
                     residual_fn = partial(residual_fn, num_input_views = 3)
 
             elif gate_residual:
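Note on the two hunks above (annotation, not part of the diff): each layer's integrator is built over `ind + 1` hiddens, and a self-attention layer flagged with `layer_qkv_receives_diff_view` gets `num_views = 3`, so queries, keys and values can each draw on a different mix of past layers. A rough sketch of the three-view output, under the assumption that `Attention` consumes the leading view dimension as separate q/k/v inputs when `qkv_receive_diff_residuals` is set:

```python
# minimal sketch with assumed shapes; mirrors DynamicLIMe(dim, num_layer_hiddens, num_views = 3)
import torch
from x_transformers.x_transformers import DynamicLIMe

lime = DynamicLIMe(dim = 512, num_layers = 3, num_views = 3)

x = torch.randn(2, 128, 512)                 # current stream entering the attention block
hiddens = torch.randn(3, 2, 128, 512)        # stacked hiddens of layers 0..2

views = lime(x, hiddens)                     # (3, 2, 128, 512)
q_inp, k_inp, v_inp = views                  # one mixed stream each for queries, keys, values
```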
@@ -2201,6 +2270,8 @@ class AttentionLayers(Module):
 
             self.skip_combines.append(skip_combine)
 
+            self.layer_integrators.append(layer_integrate)
+
             self.layers.append(ModuleList([
                 norms,
                 layer,
@@ -2341,13 +2412,13 @@ class AttentionLayers(Module):
             self.layer_types,
             self.skip_combines,
             self.layers,
-            self.layer_dropouts
+            self.layer_dropouts,
+            self.layer_integrators
         )
 
         # able to override the layers execution order on forward, for trying to depth extrapolate
 
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
@@ -2377,7 +2448,7 @@ class AttentionLayers(Module):
 
         # go through the attention and feedforward layers
 
-        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)
 
             # handle skip connections
@@ -2405,8 +2476,10 @@ class AttentionLayers(Module):
 
                 x, inner_residual, residual_kwargs = residual_fn.prepare(x)
 
-
-
+            layer_hiddens.append(x)
+
+            if exists(layer_integrator):
+                x = layer_integrator(x, layer_hiddens)
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
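To tie the forward-pass hunk together (annotation, not part of the diff): every iteration appends the incoming `x` to `layer_hiddens`, so the integrator of layer `ind` always sees exactly `ind + 1` hiddens, matching the `num_layer_hiddens = ind + 1` used at construction time. End to end, the feature is switched on through the new `integrate_layers` kwarg; a hedged usage sketch with illustrative hyperparameters, assuming the standard `TransformerWrapper`/`Decoder` wiring:

```python
# usage sketch - hyperparameters are illustrative, not taken from the package docs
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,              # new in 2.1.x: dynamic LIMe routing over past layer hiddens
        layer_integrate_use_softmax = True    # softmax (default) or ReLU over the per-layer weights
    )
)

ids = torch.randint(0, 20000, (1, 256))
logits = model(ids)                           # (1, 256, 20000)
```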

{x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.5
+Version: 2.1.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Gerasimov2025YouDN,
+    title = {You Do Not Fully Utilize Transformer's Representation Capacity},
+    author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:276317819}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/RECORD
CHANGED
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=bIlP-NHj0SB2joklpxicoaD1HVpRMGIulMF8WYEsOAQ,110076
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.1.1.dist-info/METADATA,sha256=BBGKnocyDvj_ynWM5dtrbyX1iodI4eWEnn9TWrw38kc,87275
+x_transformers-2.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.1.dist-info/RECORD,,

{x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/WHEEL
File without changes

{x_transformers-2.0.5.dist-info → x_transformers-2.1.1.dist-info}/licenses/LICENSE
File without changes