x-transformers 2.0.5__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor, cat, stack, arange
+ from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -962,6 +962,45 @@ class HyperConnection(Module):
          residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
          return rearrange(residuals, 'b n s d -> (b s) n d')

+ # LIMe - layer integrated memory (dynamic version)
+
+ class DynamicLIMe(Module):
+     def __init__(
+         self,
+         dim,
+         num_layers,
+         num_views = 1,
+         use_softmax = True
+     ):
+         super().__init__()
+         self.num_layers = num_layers
+         self.multiple_views = num_views > 1
+
+         self.to_weights = Sequential(
+             nn.Linear(dim, num_views * num_layers),
+             Rearrange('... (views layers) -> views ... layers', views = num_views),
+             nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+         )
+
+     def forward(
+         self,
+         x,
+         hiddens
+     ):
+         if not is_tensor(hiddens):
+             hiddens = stack(hiddens)
+
+         assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+         weights = self.to_weights(x)
+
+         out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+         if self.multiple_views:
+             return out
+
+         return rearrange(out, '1 ... -> ...')
+
  # token shifting

  def shift(t, amount, mask = None):
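
The new `DynamicLIMe` module turns the current hidden state into per-position weights over all previously seen layer hiddens and returns one mixed hidden per view. A minimal shape-check sketch, assuming `DynamicLIMe` is importable from `x_transformers.x_transformers` as added above; the tensor sizes are illustrative:

```python
import torch
from x_transformers.x_transformers import DynamicLIMe

dim, num_layers, batch, seq = 512, 4, 2, 16

# three views, e.g. one mixed residual each for q, k and v
lime = DynamicLIMe(dim, num_layers, num_views = 3)

# a python list of past hiddens is stacked internally (the is_tensor check in forward)
hiddens = [torch.randn(batch, seq, dim) for _ in range(num_layers)]
x = torch.randn(batch, seq, dim)

out = lime(x, hiddens)
print(out.shape)  # torch.Size([3, 2, 16, 512]) - (views, batch, seq, dim)

# with num_views = 1 the leading view dimension is squeezed away
single = DynamicLIMe(dim, num_layers)(x, hiddens)
print(single.shape)  # torch.Size([2, 16, 512])
```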
@@ -1306,7 +1345,7 @@ class Attention(Module):

          self.merge_heads = Rearrange('b h n d -> b n (h d)')

-         # whether qkv receives different residual stream combinations from hyper connections
+         # whether qkv receives different residual stream combinations from hyper connections or lime

          self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

@@ -1869,6 +1908,8 @@ class AttentionLayers(Module):
          use_layerscale = False,
          layerscale_init_value = 0.,
          unet_skips = False,
+         integrate_layers = False,
+         layer_integrate_use_softmax = True,
          num_residual_streams = 1,
          qkv_receive_diff_residuals = False,
          reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
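
These two new flags are what a user toggles: `Decoder` and `Encoder` pass their keyword arguments through to `AttentionLayers`. A minimal usage sketch (not taken from the package README; hyperparameters are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,               # new in 2.1.0: adds a DynamicLIMe per layer
        layer_integrate_use_softmax = True     # softmax over past layers (set False for relu)
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)  # (1, 1024, 20000)
```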
@@ -1895,16 +1936,30 @@ class AttentionLayers(Module):
          self.causal = causal
          self.layers = ModuleList([])

-         # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+         # routing related
+         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+         qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+         # hyper connections

          assert num_residual_streams > 0
+         has_hyper_connections = num_residual_streams > 1

          self.num_residual_streams = num_residual_streams
          self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

-         assert not (num_residual_streams > 1 and gate_residual)
+         assert not (has_hyper_connections and gate_residual)
+
+         hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers

-         assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+         # LIMe
+
+         hiddens_counter = 0
+         self.layer_integrators = ModuleList([])
+
+         assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))

          # positions related

@@ -2147,14 +2202,19 @@ class AttentionLayers(Module):

              if layer_type == 'a':
                  self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                 layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
+                 qkv_receives_diff_view = qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
+                 layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                  is_first_self_attn = False
+
              elif layer_type == 'c':
                  layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                  is_first_cross_attn = False
+
              elif layer_type == 'f':
                  layer = FeedForward(dim, **ff_kwargs)
                  layer = layer if not macaron else Scale(0.5, layer)
+
              else:
                  raise Exception(f'invalid layer type {layer_type}')

@@ -2166,10 +2226,18 @@ class AttentionLayers(Module):
              if exists(post_branch_fn):
                  layer = post_branch_fn(layer)

-             if num_residual_streams > 1:
+             layer_integrate = None
+
+             if integrate_layers:
+                 num_layer_hiddens = ind + 1
+                 layer_integrate_num_view = 3 if layer_type == 'a' and qkv_receives_diff_view else 1
+
+                 layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+             if has_hyper_connections:
                  residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)

-                 if layer_type == 'a' and qkv_receive_diff_residuals:
+                 if layer_type == 'a' and hyper_conn_produce_diff_views:
                      residual_fn = partial(residual_fn, num_input_views = 3)

              elif gate_residual:
@@ -2201,6 +2269,8 @@ class AttentionLayers(Module):

              self.skip_combines.append(skip_combine)

+             self.layer_integrators.append(layer_integrate)
+
              self.layers.append(ModuleList([
                  norms,
                  layer,
@@ -2341,13 +2411,13 @@ class AttentionLayers(Module):
              self.layer_types,
              self.skip_combines,
              self.layers,
-             self.layer_dropouts
+             self.layer_dropouts,
+             self.layer_integrators
          )

          # able to override the layers execution order on forward, for trying to depth extrapolate

          layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

          # derived input for reinjection if needed
@@ -2377,7 +2447,7 @@ class AttentionLayers(Module):

          # go through the attention and feedforward layers

-         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
              is_last = ind == (len(self.layers) - 1)

              # handle skip connections
@@ -2405,8 +2475,10 @@ class AttentionLayers(Module):

              x, inner_residual, residual_kwargs = residual_fn.prepare(x)

-             if return_hiddens:
-                 layer_hiddens.append(x)
+             layer_hiddens.append(x)
+
+             if exists(layer_integrator):
+                 x = layer_integrator(x, layer_hiddens)

              pre_norm, post_branch_norm, post_main_norm = norm

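
With `layer_hiddens` now always accumulated and handed to the integrator before each block, `layer_integrator(x, layer_hiddens)` computes a learned per-position mixture over every hidden seen so far. A plain-torch sketch of the single-view case, equivalent to the einsum inside `DynamicLIMe` (sizes are illustrative; the softmaxed `weights` stand in for the output of `to_weights`):

```python
import torch

l, b, n, d = 4, 2, 16, 512                                   # layers seen so far, batch, seq, dim
hiddens = torch.randn(l, b, n, d)                             # stacked layer_hiddens
weights = torch.softmax(torch.randn(1, b, n, l), dim = -1)    # (views = 1, batch, seq, layers)

out = torch.einsum('l b n d, v b n l -> v b n d', hiddens, weights).squeeze(0)

# same thing written as an explicit weighted sum over the layer axis
manual = (hiddens.permute(1, 2, 0, 3) * weights[0].unsqueeze(-1)).sum(dim = -2)
assert torch.allclose(out, manual, atol = 1e-5)
```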
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.0.5
+ Version: 2.1.0
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Gerasimov2025YouDN,
+     title   = {You Do Not Fully Utilize Transformer's Representation Capacity},
+     author  = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+     year    = {2025},
+     url     = {https://api.semanticscholar.org/CorpusID:276317819}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=yijnlpQnhC0lK5qYzSxII7IkVf7ILhsTyntw_S5MvRU,107670
+ x_transformers/x_transformers.py,sha256=PhwkSTxLYxFPLn-mjVe6t5LrrZEKsS8P03u7Q9v_KYM,110061
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-2.0.5.dist-info/METADATA,sha256=9U0kHbTwa2sv4z-pCqlkcm998SDcELMonQ9JoHaYgR4,86938
- x_transformers-2.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.0.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.0.5.dist-info/RECORD,,
+ x_transformers-2.1.0.dist-info/METADATA,sha256=BzidlkOJz0xRRSpZjOxcph8e_K16QOpELw6JUeDS9mQ,87275
+ x_transformers-2.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.1.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.1.0.dist-info/RECORD,,