x-transformers 2.0.5__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +85 -13
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/METADATA +10 -1
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/RECORD +5 -5
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange
+from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -962,6 +962,45 @@ class HyperConnection(Module):
         residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
         return rearrange(residuals, 'b n s d -> (b s) n d')
 
+# LIMe - layer integrated memory (dynamic version)
+
+class DynamicLIMe(Module):
+    def __init__(
+        self,
+        dim,
+        num_layers,
+        num_views = 1,
+        use_softmax = True
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.multiple_views = num_views > 1
+
+        self.to_weights = Sequential(
+            nn.Linear(dim, num_views * num_layers),
+            Rearrange('... (views layers) -> views ... layers', views = num_views),
+            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        hiddens
+    ):
+        if not is_tensor(hiddens):
+            hiddens = stack(hiddens)
+
+        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+        weights = self.to_weights(x)
+
+        out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+        if self.multiple_views:
+            return out
+
+        return rearrange(out, '1 ... -> ...')
+
 # token shifting
 
 def shift(t, amount, mask = None):
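The new `DynamicLIMe` module learns, per token, a weighted mixture over the hiddens of all earlier layers. A minimal shape-level sketch of the single-view case (sizes and values are illustrative only; it assumes `DynamicLIMe` is imported directly from `x_transformers.x_transformers`, where this diff defines it):

```python
import torch
from x_transformers.x_transformers import DynamicLIMe

batch, seq, dim, num_layers = 2, 16, 64, 4

lime = DynamicLIMe(dim = dim, num_layers = num_layers)   # num_views defaults to 1

x = torch.randn(batch, seq, dim)                         # current stream, drives the routing weights
hiddens = [torch.randn(batch, seq, dim) for _ in range(num_layers)]  # one hidden per earlier layer

out = lime(x, hiddens)   # list is stacked internally, then mixed per token over the layer axis
assert out.shape == (batch, seq, dim)
```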
@@ -1306,7 +1345,7 @@ class Attention(Module):
 
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        # whether qkv receives different residual stream combinations from hyper connections
+        # whether qkv receives different residual stream combinations from hyper connections or lime
 
         self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
 
@@ -1869,6 +1908,8 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        integrate_layers = False,
+        layer_integrate_use_softmax = True,
         num_residual_streams = 1,
         qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
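The two new keyword arguments are consumed by `AttentionLayers`, so they can be passed through the usual wrappers. A hedged usage sketch (standard `TransformerWrapper`/`Decoder` setup, which forwards keyword arguments to `AttentionLayers`; only the two flags marked new come from this release):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,              # new in 2.1.0: adds a DynamicLIMe router per layer
        layer_integrate_use_softmax = True    # new in 2.1.0: softmax routing weights (False falls back to ReLU)
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)    # (1, 1024, 256)
```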
@@ -1895,16 +1936,30 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
-        #
+        # routing related
+        # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+        # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+        qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+        # hyper connections
 
         assert num_residual_streams > 0
+        has_hyper_connections = num_residual_streams > 1
 
         self.num_residual_streams = num_residual_streams
         self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
 
-        assert not (
+        assert not (has_hyper_connections and gate_residual)
+
+        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
 
-
+        # LIMe
+
+        hiddens_counter = 0
+        self.layer_integrators = ModuleList([])
+
+        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
@@ -2147,14 +2202,19 @@ class AttentionLayers(Module):
 
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-
+                qkv_receives_diff_view = qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                 is_first_self_attn = False
+
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
+
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
+
             else:
                 raise Exception(f'invalid layer type {layer_type}')
 
@@ -2166,10 +2226,18 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)
 
-
+            layer_integrate = None
+
+            if integrate_layers:
+                num_layer_hiddens = ind + 1
+                layer_integrate_num_view = 3 if layer_type == 'a' and qkv_receives_diff_view else 1
+
+                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+            if has_hyper_connections:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
 
-                if layer_type == 'a' and
+                if layer_type == 'a' and hyper_conn_produce_diff_views:
                     residual_fn = partial(residual_fn, num_input_views = 3)
 
             elif gate_residual:
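A quick sketch of the `num_views = 3` case configured above for attention layers, where `DynamicLIMe` returns one routed stream each for q, k and v (shapes are illustrative only):

```python
import torch
from x_transformers.x_transformers import DynamicLIMe

lime = DynamicLIMe(dim = 64, num_layers = 5, num_views = 3)

x = torch.randn(2, 16, 64)
hiddens = torch.randn(5, 2, 16, 64)       # (layers, batch, seq, dim)

qkv_views = lime(x, hiddens)
assert qkv_views.shape == (3, 2, 16, 64)  # a separate mixture for q, k and v
```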
@@ -2201,6 +2269,8 @@ class AttentionLayers(Module):
 
             self.skip_combines.append(skip_combine)
 
+            self.layer_integrators.append(layer_integrate)
+
             self.layers.append(ModuleList([
                 norms,
                 layer,
@@ -2341,13 +2411,13 @@ class AttentionLayers(Module):
             self.layer_types,
             self.skip_combines,
             self.layers,
-            self.layer_dropouts
+            self.layer_dropouts,
+            self.layer_integrators
         )
 
         # able to override the layers execution order on forward, for trying to depth extrapolate
 
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
@@ -2377,7 +2447,7 @@ class AttentionLayers(Module):
 
         # go through the attention and feedforward layers
 
-        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)
 
             # handle skip connections
@@ -2405,8 +2475,10 @@ class AttentionLayers(Module):
 
             x, inner_residual, residual_kwargs = residual_fn.prepare(x)
 
-
-
+            layer_hiddens.append(x)
+
+            if exists(layer_integrator):
+                x = layer_integrator(x, layer_hiddens)
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
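Putting the forward-pass change together: each layer first appends the incoming stream to `layer_hiddens`, then (when present) its integrator replaces the stream with a learned mixture of everything collected so far. A self-contained toy of just that routing step, with the attention/feedforward blocks omitted and sizes chosen for illustration:

```python
import torch
from torch import nn
from x_transformers.x_transformers import DynamicLIMe

dim, depth = 32, 4
# mirrors num_layer_hiddens = ind + 1 above: layer i routes over i + 1 hiddens
integrators = nn.ModuleList([DynamicLIMe(dim, i + 1) for i in range(depth)])

x = torch.randn(2, 8, dim)
layer_hiddens = []

for integrator in integrators:
    layer_hiddens.append(x)
    x = integrator(x, layer_hiddens)   # learned mixture over all hiddens so far
    # ... the attention / feedforward block and its residual would run here

assert x.shape == (2, 8, dim)
```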
{x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.5
+Version: 2.1.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Gerasimov2025YouDN,
+    title   = {You Do Not Fully Utilize Transformer's Representation Capacity},
+    author  = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:276317819}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/RECORD
CHANGED
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=PhwkSTxLYxFPLn-mjVe6t5LrrZEKsS8P03u7Q9v_KYM,110061
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.0.
-x_transformers-2.0.
-x_transformers-2.0.
-x_transformers-2.0.
+x_transformers-2.1.0.dist-info/METADATA,sha256=BzidlkOJz0xRRSpZjOxcph8e_K16QOpELw6JUeDS9mQ,87275
+x_transformers-2.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.0.dist-info/RECORD,,
{x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/WHEEL
File without changes

{x_transformers-2.0.5.dist-info → x_transformers-2.1.0.dist-info}/licenses/LICENSE
File without changes