x-transformers 2.0.4__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor, cat, stack, arange
+from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict

@@ -449,17 +449,16 @@ class DynamicPositionBias(Module):
         return next(self.parameters()).device

     def forward(self, i, j):
-        assert i == j
         n, device = j, self.device

         # get the (n x n) matrix of distances
-        seq_arange = arange(n, device = device)
-        context_arange = arange(n, device = device)
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
-        indices += (n - 1)
+        indices += (j - 1)

         # input to continuous positions MLP
-        pos = arange(-n + 1, n, device = device).float()
+        pos = arange(-j + 1, j, device = device).float()
         pos = rearrange(pos, '... -> ... 1')

         if self.log_distance:
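
The hunk above drops the `i == j` restriction in `DynamicPositionBias.forward`: the `i` queries are now treated as the last `i` positions of a length-`j` context, so the bias also works when querying fewer positions than there are keys (e.g. cached decoding). A quick standalone check of the new index arithmetic (illustration only, not library code):

```python
import torch

# with i queries attending over j keys (i <= j), the queries are taken to be
# the last i positions of the length-j context
i, j = 2, 5

seq_arange = torch.arange(j - i, j)      # query positions: the last i of j
context_arange = torch.arange(j)         # key positions: all j

# signed distances shifted into [0, 2j - 2], valid indices into a bias table
# built from relative positions arange(-j + 1, j)
indices = seq_arange[:, None] - context_arange[None, :] + (j - 1)

print(indices)
# tensor([[7, 6, 5, 4, 3],
#         [8, 7, 6, 5, 4]])
```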
@@ -963,6 +962,45 @@ class HyperConnection(Module):
         residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
         return rearrange(residuals, 'b n s d -> (b s) n d')

+# LIMe - layer integrated memory (dynamic version)
+
+class DynamicLIMe(Module):
+    def __init__(
+        self,
+        dim,
+        num_layers,
+        num_views = 1,
+        use_softmax = True
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+        self.multiple_views = num_views > 1
+
+        self.to_weights = Sequential(
+            nn.Linear(dim, num_views * num_layers),
+            Rearrange('... (views layers) -> views ... layers', views = num_views),
+            nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        hiddens
+    ):
+        if not is_tensor(hiddens):
+            hiddens = stack(hiddens)
+
+        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+        weights = self.to_weights(x)
+
+        out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+        if self.multiple_views:
+            return out
+
+        return rearrange(out, '1 ... -> ...')
+
 # token shifting

 def shift(t, amount, mask = None):
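
The new `DynamicLIMe` module implements the layer-integrated-memory routing from https://arxiv.org/abs/2502.09245: for each token it predicts mixing weights over the hidden states of all preceding layers and returns their weighted combination. A minimal standalone sketch, assuming the class can be imported from `x_transformers.x_transformers`:

```python
import torch
from x_transformers.x_transformers import DynamicLIMe  # assumed import path

dim, num_layers = 512, 4

lime = DynamicLIMe(dim, num_layers = num_layers)  # num_views defaults to 1

x = torch.randn(2, 1024, dim)                                      # current layer input (b, n, d)
hiddens = [torch.randn(2, 1024, dim) for _ in range(num_layers)]   # hidden states of the 4 preceding layers

out = lime(x, hiddens)             # per-token softmax mixture over the 4 layers
assert out.shape == (2, 1024, dim)
```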
@@ -1307,7 +1345,7 @@ class Attention(Module):

         self.merge_heads = Rearrange('b h n d -> b n (h d)')

-        # whether qkv receives different residual stream combinations from hyper connections
+        # whether qkv receives different residual stream combinations from hyper connections or lime

         self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

@@ -1870,6 +1908,8 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
+        integrate_layers = False,
+        layer_integrate_use_softmax = True,
         num_residual_streams = 1,
         qkv_receive_diff_residuals = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
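
The two new keyword arguments above expose LIMe on `AttentionLayers`. A hedged usage sketch, assuming `Decoder` forwards these kwargs to `AttentionLayers` the same way it does for the existing options:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,               # enable per-layer LIMe routing (new in 2.1.0)
        layer_integrate_use_softmax = True     # softmax mixing weights (default); False switches to ReLU
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)
```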
@@ -1896,16 +1936,30 @@ class AttentionLayers(Module):
         self.causal = causal
         self.layers = ModuleList([])

-        # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+        # routing related
+        # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+        # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+        qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+        # hyper connections

         assert num_residual_streams > 0
+        has_hyper_connections = num_residual_streams > 1

         self.num_residual_streams = num_residual_streams
         self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

-        assert not (num_residual_streams > 1 and gate_residual)
+        assert not (has_hyper_connections and gate_residual)

-        assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+        hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
+
+        # LIMe
+
+        hiddens_counter = 0
+        self.layer_integrators = ModuleList([])
+
+        assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))

         # positions related

@@ -2148,14 +2202,19 @@ class AttentionLayers(Module):

             if layer_type == 'a':
                 self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
+                qkv_receives_diff_view = qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
+                layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                 is_first_self_attn = False
+
             elif layer_type == 'c':
                 layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                 is_first_cross_attn = False
+
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
+
             else:
                 raise Exception(f'invalid layer type {layer_type}')

@@ -2167,10 +2226,18 @@ class AttentionLayers(Module):
             if exists(post_branch_fn):
                 layer = post_branch_fn(layer)

-            if num_residual_streams > 1:
+            layer_integrate = None
+
+            if integrate_layers:
+                num_layer_hiddens = ind + 1
+                layer_integrate_num_view = 3 if layer_type == 'a' and qkv_receives_diff_view else 1
+
+                layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+            if has_hyper_connections:
                 residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)

-                if layer_type == 'a' and qkv_receive_diff_residuals:
+                if layer_type == 'a' and hyper_conn_produce_diff_views:
                     residual_fn = partial(residual_fn, num_input_views = 3)

             elif gate_residual:
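
For attention layers that receive different views for q, k and v, the integrator above is built with `num_views = 3`, so it returns one mixed hidden state per view. Illustrative sketch of that shape contract (same assumed import path as above):

```python
import torch
from x_transformers.x_transformers import DynamicLIMe  # assumed import path

lime = DynamicLIMe(dim = 64, num_layers = 2, num_views = 3)

x = torch.randn(1, 16, 64)
hiddens = [torch.randn(1, 16, 64) for _ in range(2)]

views = lime(x, hiddens)
assert views.shape == (3, 1, 16, 64)   # (views, batch, seq, dim)

q_input, k_input, v_input = views      # one integrated stream per q, k and v
```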
@@ -2202,6 +2269,8 @@ class AttentionLayers(Module):

             self.skip_combines.append(skip_combine)

+            self.layer_integrators.append(layer_integrate)
+
             self.layers.append(ModuleList([
                 norms,
                 layer,
@@ -2342,13 +2411,13 @@ class AttentionLayers(Module):
             self.layer_types,
             self.skip_combines,
             self.layers,
-            self.layer_dropouts
+            self.layer_dropouts,
+            self.layer_integrators
         )

         # able to override the layers execution order on forward, for trying to depth extrapolate

         layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

         # derived input for reinjection if needed
@@ -2378,7 +2447,7 @@ class AttentionLayers(Module):

         # go through the attention and feedforward layers

-        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)

             # handle skip connections
@@ -2406,8 +2475,10 @@ class AttentionLayers(Module):

             x, inner_residual, residual_kwargs = residual_fn.prepare(x)

-            if return_hiddens:
-                layer_hiddens.append(x)
+            layer_hiddens.append(x)
+
+            if exists(layer_integrator):
+                x = layer_integrator(x, layer_hiddens)

             pre_norm, post_branch_norm, post_main_norm = norm

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.4
+Version: 2.1.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@inproceedings{Gerasimov2025YouDN,
+    title   = {You Do Not Fully Utilize Transformer's Representation Capacity},
+    author  = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:276317819}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=iE4m38BUwCB1aENGLV5dMsIuu1t3CElEBKuXfkJfPA4,107685
+x_transformers/x_transformers.py,sha256=PhwkSTxLYxFPLn-mjVe6t5LrrZEKsS8P03u7Q9v_KYM,110061
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.0.4.dist-info/METADATA,sha256=UbaywSq7GvNJLub5VFrsooDeUgohEzWWBtA9ZnNOxkI,86938
-x_transformers-2.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.0.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.0.4.dist-info/RECORD,,
+x_transformers-2.1.0.dist-info/METADATA,sha256=BzidlkOJz0xRRSpZjOxcph8e_K16QOpELw6JUeDS9mQ,87275
+x_transformers-2.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.0.dist-info/RECORD,,