x-transformers 2.0.4__py3-none-any.whl → 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +89 -18
- {x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/METADATA +10 -1
- {x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/RECORD +5 -5
- {x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange
+from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict
 
@@ -449,17 +449,16 @@ class DynamicPositionBias(Module):
         return next(self.parameters()).device
 
     def forward(self, i, j):
-        assert i == j
         n, device = j, self.device
 
         # get the (n x n) matrix of distances
-        seq_arange = arange(
-        context_arange = arange(
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
-        indices += (
+        indices += (j - 1)
 
         # input to continuous positions MLP
-        pos = arange(-
+        pos = arange(-j + 1, j, device = device).float()
         pos = rearrange(pos, '... -> ... 1')
 
         if self.log_distance:
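Note on the `DynamicPositionBias` change above: dropping `assert i == j` lets the bias cover rectangular attention maps (for example, a few query tokens attending over a longer key/value context), with `j - i` offsetting the query positions. A minimal sketch of the index computation, using plain broadcasting in place of `einx.subtract` (shapes here are illustrative, not from the diff):

```python
import torch
from torch import arange

# illustrative shapes: i = 3 queries attending over j = 5 keys
i, j = 3, 5

seq_arange = arange(j - i, j)      # query positions: tensor([2, 3, 4])
context_arange = arange(j)         # key positions:   tensor([0, 1, 2, 3, 4])

# same result as einx.subtract('i, j -> i j', ...): relative distance per (query, key) pair
indices = seq_arange[:, None] - context_arange[None, :]   # shape (i, j)
indices += (j - 1)                 # shift into [0, 2j - 2] to index the (2j - 1) bias positions

print(indices)
# tensor([[6, 5, 4, 3, 2],
#         [7, 6, 5, 4, 3],
#         [8, 7, 6, 5, 4]])
```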
@@ -963,6 +962,45 @@ class HyperConnection(Module):
         residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
         return rearrange(residuals, 'b n s d -> (b s) n d')
 
+# LIMe - layer integrated memory (dynamic version)
+
+class DynamicLIMe(Module):
+    def __init__(
+        self,
+        dim,
+        num_layers,
+        num_views = 1,
+        use_softmax = True
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.multiple_views = num_views > 1
+
+        self.to_weights = Sequential(
+            nn.Linear(dim, num_views * num_layers),
+            Rearrange('... (views layers) -> views ... layers', views = num_views),
+            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        hiddens
+    ):
+        if not is_tensor(hiddens):
+            hiddens = stack(hiddens)
+
+        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+        weights = self.to_weights(x)
+
+        out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+        if self.multiple_views:
+            return out
+
+        return rearrange(out, '1 ... -> ...')
+
 # token shifting
 
 def shift(t, amount, mask = None):
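Illustrative usage of the new `DynamicLIMe` block (not part of the diff; it assumes x-transformers 2.1.0 is installed so the class above is importable from `x_transformers.x_transformers`):

```python
import torch
from x_transformers.x_transformers import DynamicLIMe

dim, num_layers, batch, seq = 64, 4, 2, 16

# outputs of four earlier layers; forward stacks them along a leading "layers" dimension
hiddens = [torch.randn(batch, seq, dim) for _ in range(num_layers)]
x = hiddens[-1]                          # current stream, used to predict per-token routing weights

lime = DynamicLIMe(dim, num_layers)      # single view, softmax routing
out = lime(x, hiddens)                   # per-token convex mix over the 4 layer hiddens
assert out.shape == (batch, seq, dim)

# with three views, one mixed stream is produced per view (e.g. separate q / k / v inputs)
lime_qkv = DynamicLIMe(dim, num_layers, num_views = 3)
q_in, k_in, v_in = lime_qkv(x, hiddens)  # each (batch, seq, dim)
```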
@@ -1307,7 +1345,7 @@ class Attention(Module):
 
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        # whether qkv receives different residual stream combinations from hyper connections
+        # whether qkv receives different residual stream combinations from hyper connections or lime
 
         self.qkv_receive_diff_residuals = qkv_receive_diff_residuals
 
@@ -1870,6 +1908,8 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        integrate_layers = False,
+        layer_integrate_use_softmax = True,
         num_residual_streams = 1,
         qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
@@ -1896,16 +1936,30 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])
 
-        #
+        # routing related
+        # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+        # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+        qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+        # hyper connections
 
         assert num_residual_streams > 0
+        has_hyper_connections = num_residual_streams > 1
 
         self.num_residual_streams = num_residual_streams
         self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None
 
-        assert not (
+        assert not (has_hyper_connections and gate_residual)
 
-
+        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
+
+        # LIMe
+
+        hiddens_counter = 0
+        self.layer_integrators = ModuleList([])
+
+        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))
 
         # positions related
 
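The new constructor flags are reachable from the usual high-level API. A hedged usage sketch (assumes x-transformers 2.1.0, with `Decoder` forwarding these keyword arguments down to `AttentionLayers`):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,               # each block gets a DynamicLIMe over all earlier hiddens
        layer_integrate_use_softmax = True     # softmax (convex) routing weights; False swaps in ReLU
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)                         # (1, 128, 256)
```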
@@ -2148,14 +2202,19 @@ class AttentionLayers(Module):
 
             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-
+                qkv_receives_diff_view = qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                 is_first_self_attn = False
+
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
+
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
+
             else:
                 raise Exception(f'invalid layer type {layer_type}')
 
@@ -2167,10 +2226,18 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)
 
-
+            layer_integrate = None
+
+            if integrate_layers:
+                num_layer_hiddens = ind + 1
+                layer_integrate_num_view = 3 if layer_type == 'a' and qkv_receives_diff_view else 1
+
+                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+            if has_hyper_connections:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)
 
-                if layer_type == 'a' and
+                if layer_type == 'a' and hyper_conn_produce_diff_views:
                     residual_fn = partial(residual_fn, num_input_views = 3)
 
             elif gate_residual:
@@ -2202,6 +2269,8 @@ class AttentionLayers(Module):
 
             self.skip_combines.append(skip_combine)
 
+            self.layer_integrators.append(layer_integrate)
+
             self.layers.append(ModuleList([
                 norms,
                 layer,
@@ -2342,13 +2411,13 @@ class AttentionLayers(Module):
             self.layer_types,
             self.skip_combines,
             self.layers,
-            self.layer_dropouts
+            self.layer_dropouts,
+            self.layer_integrators
         )
 
         # able to override the layers execution order on forward, for trying to depth extrapolate
 
         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
         # derived input for reinjection if needed
@@ -2378,7 +2447,7 @@ class AttentionLayers(Module):
 
         # go through the attention and feedforward layers
 
-        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)
 
             # handle skip connections
@@ -2406,8 +2475,10 @@ class AttentionLayers(Module):
 
             x, inner_residual, residual_kwargs = residual_fn.prepare(x)
 
-
-
+            layer_hiddens.append(x)
+
+            if exists(layer_integrator):
+                x = layer_integrator(x, layer_hiddens)
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
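To make the forward-pass change easier to follow, here is a simplified, hypothetical sketch of the per-layer integration loop; variable names mirror the diff, but skips, norms, residual streams and caching are omitted:

```python
import torch
from torch import nn
from x_transformers.x_transformers import DynamicLIMe

dim, depth, batch, seq = 64, 4, 2, 16
x = torch.randn(batch, seq, dim)

# one integrator per block; block i can look back over the i + 1 hiddens recorded before it runs
layer_integrators = [DynamicLIMe(dim, num_layers = i + 1) for i in range(depth)]
blocks = [nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim)) for _ in range(depth)]

layer_hiddens = []

for layer_integrator, block in zip(layer_integrators, blocks):
    layer_hiddens.append(x)                     # record the stream entering this block

    if layer_integrator is not None:            # mix all hiddens collected so far (DynamicLIMe)
        x = layer_integrator(x, layer_hiddens)

    x = block(x) + x                            # stand-in for the attention / feedforward branch
```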
{x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.4
+Version: 2.1.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Gerasimov2025YouDN,
+    title = {You Do Not Fully Utilize Transformer's Representation Capacity},
+    author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+    year = {2025},
+    url = {https://api.semanticscholar.org/CorpusID:276317819}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

{x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/RECORD
CHANGED
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=PhwkSTxLYxFPLn-mjVe6t5LrrZEKsS8P03u7Q9v_KYM,110061
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.0.4.dist-info/METADATA,sha256=
-x_transformers-2.0.4.dist-info/WHEEL,sha256=
-x_transformers-2.0.4.dist-info/licenses/LICENSE,sha256=
-x_transformers-2.0.4.dist-info/RECORD,,
+x_transformers-2.1.0.dist-info/METADATA,sha256=BzidlkOJz0xRRSpZjOxcph8e_K16QOpELw6JUeDS9mQ,87275
+x_transformers-2.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.0.dist-info/RECORD,,

{x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/WHEEL
File without changes

{x_transformers-2.0.4.dist-info → x_transformers-2.1.0.dist-info}/licenses/LICENSE
File without changes