x-transformers 2.0.5__py3-none-any.whl → 2.1.1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

x_transformers/x_transformers.py
@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor, cat, stack, arange
+ from torch import nn, einsum, Tensor, cat, stack, arange, is_tensor
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -962,6 +962,45 @@ class HyperConnection(Module):
          residuals = einsum('b n d, b n s -> b n s d', x, beta) + residuals
          return rearrange(residuals, 'b n s d -> (b s) n d')

+ # LIMe - layer integrated memory (dynamic version)
+
+ class DynamicLIMe(Module):
+     def __init__(
+         self,
+         dim,
+         num_layers,
+         num_views = 1,
+         use_softmax = True
+     ):
+         super().__init__()
+         self.num_layers = num_layers
+         self.multiple_views = num_views > 1
+
+         self.to_weights = Sequential(
+             nn.Linear(dim, num_views * num_layers),
+             Rearrange('... (views layers) -> views ... layers', views = num_views),
+             nn.Softmax(dim = -1) if use_softmax else nn.ReLU()
+         )
+
+     def forward(
+         self,
+         x,
+         hiddens
+     ):
+         if not is_tensor(hiddens):
+             hiddens = stack(hiddens)
+
+         assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+         weights = self.to_weights(x)
+
+         out = einsum('l b n d, v b n l -> v b n d', hiddens, weights)
+
+         if self.multiple_views:
+             return out
+
+         return rearrange(out, '1 ... -> ...')
+
  # token shifting

  def shift(t, amount, mask = None):
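
For orientation: the added DynamicLIMe module learns, for every token, a weighting over the hidden states of all preceding layers (softmax-normalized by default) and mixes them with the einsum in its forward. Below is a minimal sketch of exercising the class in isolation; the tensor sizes are illustrative, and the import path assumes DynamicLIMe stays at module level in x_transformers/x_transformers.py, as this diff places it.

```python
import torch
from x_transformers.x_transformers import DynamicLIMe  # assumed import path, per this diff

batch, seq_len, dim, num_layers = 2, 16, 64, 4  # illustrative sizes

lime = DynamicLIMe(dim = dim, num_layers = num_layers)

# one hidden state per preceding layer, each of shape (batch, seq, dim)
hiddens = [torch.randn(batch, seq_len, dim) for _ in range(num_layers)]
x = hiddens[-1]  # the current stream drives the per-token layer weights

out = lime(x, hiddens)  # per-token mixture over the 4 layers
assert out.shape == (batch, seq_len, dim)
```
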
@@ -1306,7 +1345,7 @@ class Attention(Module):

          self.merge_heads = Rearrange('b h n d -> b n (h d)')

-         # whether qkv receives different residual stream combinations from hyper connections
+         # whether qkv receives different residual stream combinations from hyper connections or lime

          self.qkv_receive_diff_residuals = qkv_receive_diff_residuals

@@ -1869,6 +1908,8 @@ class AttentionLayers(Module):
          use_layerscale = False,
          layerscale_init_value = 0.,
          unet_skips = False,
+         integrate_layers = False,
+         layer_integrate_use_softmax = True,
          num_residual_streams = 1,
          qkv_receive_diff_residuals = False,
          reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
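
The two new keyword arguments switch LIMe on for an AttentionLayers stack. A hedged usage sketch follows, assuming Decoder forwards integrate_layers and layer_integrate_use_softmax to AttentionLayers just like its other keyword arguments; all sizes are illustrative.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        integrate_layers = True,             # new in 2.1.1: per-layer DynamicLIMe mixing
        layer_integrate_use_softmax = True   # softmax (default) vs. relu over past layers
    )
)

tokens = torch.randint(0, 20000, (1, 256))
logits = model(tokens)  # (1, 256, 20000)
```
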
@@ -1895,16 +1936,30 @@ class AttentionLayers(Module):
          self.causal = causal
          self.layers = ModuleList([])

-         # greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+         # routing related
+         # 1. greater than one residual stream, proposed in Hyper-Connections paper https://arxiv.org/abs/2409.19606
+         # 2. integrating more than one past layer, from LIMe paper https://arxiv.org/abs/2502.09245
+
+         qkv_receive_diff_residuals |= integrate_layers # qkv always receives different views if integrating layers
+
+         # hyper connections

          assert num_residual_streams > 0
+         has_hyper_connections = num_residual_streams > 1

          self.num_residual_streams = num_residual_streams
          self.stream_emb = nn.Parameter(torch.zeros(num_residual_streams, dim)) if num_residual_streams > 1 else None

-         assert not (num_residual_streams > 1 and gate_residual)
+         assert not (has_hyper_connections and gate_residual)
+
+         hyper_conn_produce_diff_views = qkv_receive_diff_residuals and not integrate_layers
+
+         # LIMe

-         assert not (num_residual_streams == 1 and qkv_receive_diff_residuals)
+         hiddens_counter = 0
+         self.layer_integrators = ModuleList([])
+
+         assert not (qkv_receive_diff_residuals and not (hyper_conn_produce_diff_views or integrate_layers))

          # positions related

@@ -2145,16 +2200,22 @@ class AttentionLayers(Module):

              # attention, cross attention, feedforward

+             layer_qkv_receives_diff_view = layer_type == 'a' and qkv_receive_diff_residuals and not (is_first_self_attn and integrate_layers)
+
              if layer_type == 'a':
                  self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
-                 layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = qkv_receive_diff_residuals, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
+
+                 layer = Attention(dim, heads = heads, causal = causal, qkv_receive_diff_residuals = layer_qkv_receives_diff_view, learned_value_residual_mix = self_attn_learned_value_residual, rotate_num_heads = rotate_num_heads, **attn_kwargs)
                  is_first_self_attn = False
+
              elif layer_type == 'c':
                  layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
                  is_first_cross_attn = False
+
              elif layer_type == 'f':
                  layer = FeedForward(dim, **ff_kwargs)
                  layer = layer if not macaron else Scale(0.5, layer)
+
              else:
                  raise Exception(f'invalid layer type {layer_type}')

@@ -2166,10 +2227,18 @@ class AttentionLayers(Module):
              if exists(post_branch_fn):
                  layer = post_branch_fn(layer)

-             if num_residual_streams > 1:
+             layer_integrate = None
+
+             if integrate_layers:
+                 num_layer_hiddens = ind + 1
+                 layer_integrate_num_view = 3 if layer_qkv_receives_diff_view else 1
+
+                 layer_integrate = DynamicLIMe(dim, num_layer_hiddens, num_views = layer_integrate_num_view, use_softmax = layer_integrate_use_softmax)
+
+             if has_hyper_connections:
                  residual_fn = partial(HyperConnection, num_residual_streams = num_residual_streams)

-                 if layer_type == 'a' and qkv_receive_diff_residuals:
+                 if layer_type == 'a' and hyper_conn_produce_diff_views:
                      residual_fn = partial(residual_fn, num_input_views = 3)

              elif gate_residual:
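
The num_views = 3 branch above pairs LIMe with the qkv_receive_diff_residuals path: a self-attention layer that receives different views gets one mixture each for q, k and v. A small sketch of that shape contract, with illustrative sizes and the same assumed import path as earlier.

```python
import torch
from x_transformers.x_transformers import DynamicLIMe  # assumed import path, per this diff

# built as in the branch above when layer_qkv_receives_diff_view is True
lime = DynamicLIMe(dim = 64, num_layers = 5, num_views = 3)

hiddens = torch.randn(5, 2, 16, 64)  # already stacked: (layers, batch, seq, dim)
x = hiddens[-1]                      # current stream, (batch, seq, dim)

qkv_views = lime(x, hiddens)         # one mixed view per q, k and v
assert qkv_views.shape == (3, 2, 16, 64)
```
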
@@ -2201,6 +2270,8 @@ class AttentionLayers(Module):

              self.skip_combines.append(skip_combine)

+             self.layer_integrators.append(layer_integrate)
+
              self.layers.append(ModuleList([
                  norms,
                  layer,
@@ -2341,13 +2412,13 @@ class AttentionLayers(Module):
              self.layer_types,
              self.skip_combines,
              self.layers,
-             self.layer_dropouts
+             self.layer_dropouts,
+             self.layer_integrators
          )

          # able to override the layers execution order on forward, for trying to depth extrapolate

          layers_execute_order = default(layers_execute_order, self.layers_execute_order)
-
          layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

          # derived input for reinjection if needed
@@ -2377,7 +2448,7 @@ class AttentionLayers(Module):

          # go through the attention and feedforward layers

-         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
+         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout, layer_integrator) in enumerate(zip(*layer_variables)):
              is_last = ind == (len(self.layers) - 1)

              # handle skip connections
@@ -2405,8 +2476,10 @@ class AttentionLayers(Module):

              x, inner_residual, residual_kwargs = residual_fn.prepare(x)

-             if return_hiddens:
-                 layer_hiddens.append(x)
+             layer_hiddens.append(x)
+
+             if exists(layer_integrator):
+                 x = layer_integrator(x, layer_hiddens)

              pre_norm, post_branch_norm, post_main_norm = norm

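
To make the new bookkeeping concrete: every iteration now records the pre-block stream in layer_hiddens, and when a layer integrator exists the block input is replaced by its mixture over everything collected so far. A standalone toy loop under the same assumed import path, with the actual attention and feedforward blocks elided.

```python
import torch
from x_transformers.x_transformers import DynamicLIMe  # assumed import path, per this diff

dim, depth = 32, 4  # illustrative sizes

# layer i can look back over i + 1 hidden states, mirroring num_layer_hiddens = ind + 1
integrators = [DynamicLIMe(dim, num_layers = i + 1) for i in range(depth)]

x = torch.randn(2, 8, dim)
layer_hiddens = []

for layer_integrator in integrators:
    layer_hiddens.append(x)
    x = layer_integrator(x, layer_hiddens)  # mixture over all layers seen so far
    # ... the attention / feedforward block and residual update would run here
```
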

x_transformers-2.0.5.dist-info/METADATA → x_transformers-2.1.1.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.0.5
+ Version: 2.1.1
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2427,4 +2427,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Gerasimov2025YouDN,
+     title = {You Do Not Fully Utilize Transformer's Representation Capacity},
+     author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+     year = {2025},
+     url = {https://api.semanticscholar.org/CorpusID:276317819}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis

x_transformers-2.0.5.dist-info/RECORD → x_transformers-2.1.1.dist-info/RECORD
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=yijnlpQnhC0lK5qYzSxII7IkVf7ILhsTyntw_S5MvRU,107670
+ x_transformers/x_transformers.py,sha256=bIlP-NHj0SB2joklpxicoaD1HVpRMGIulMF8WYEsOAQ,110076
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-2.0.5.dist-info/METADATA,sha256=9U0kHbTwa2sv4z-pCqlkcm998SDcELMonQ9JoHaYgR4,86938
- x_transformers-2.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.0.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.0.5.dist-info/RECORD,,
+ x_transformers-2.1.1.dist-info/METADATA,sha256=BBGKnocyDvj_ynWM5dtrbyX1iodI4eWEnn9TWrw38kc,87275
+ x_transformers-2.1.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.1.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.1.1.dist-info/RECORD,,