x-transformers 1.35.2__py3-none-any.whl → 1.36.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import math
-from random import random
+from random import random, randrange
 from packaging import version

 import torch
@@ -12,6 +12,7 @@ from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import List, Dict, Tuple, Callable

@@ -920,6 +921,7 @@ class Attention(Module):
         kv_heads = None,
         shared_kv = False,
         value_dim_head = None,
+        dim_out = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
@@ -1057,7 +1059,11 @@ class Attention(Module):
         # attention on attention

         self.attn_on_attn = on_attn
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
+
+        # output dimension by default same as input, but can be overridden
+
+        dim_out = default(dim_out, dim)
+        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)

         # whether to rotate positions into values, for absolute positions in addition to relative

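Taken together, the two Attention hunks above let the block project to an output width other than `dim`. Below is a minimal sketch of how that might look from the caller's side; the defaults `heads = 8, dim_head = 64` (giving `out_dim = 512`) and the tuple return of `Attention.forward` are assumptions about the surrounding code, not part of this diff.

import torch
from x_transformers.x_transformers import Attention

# leaving dim_out as None keeps the old behaviour (project back to dim);
# overriding it only changes the final to_out projection, per the hunk above
attn = Attention(dim = 512, heads = 8, dim_head = 64, dim_out = 256)
print(attn.to_out)   # assumed: Linear(in_features = 512, out_features = 256, bias = False)

x = torch.randn(2, 1024, 512)
out, intermediates = attn(x)   # out is expected to have feature dimension 256
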
@@ -1982,7 +1988,9 @@ class TransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        recycling = False, # from Jumper et al. - Alphafold2
+        train_max_recycle_steps = 4, # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
@@ -2039,6 +2047,13 @@ class TransformerWrapper(Module):

         assert at_most_one_of(average_pool_embed, use_cls_token)

+        # maybe recycling
+
+        self.recycling = recycling
+        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+
+        self.train_max_recycle_steps = train_max_recycle_steps
+
         # classic cls token from the bert days

         self.cls_token = None
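
The two constructor hunks above wire the new keyword arguments into `TransformerWrapper` and set up the recycling projection. A hedged sketch of enabling the feature follows; the decoder hyperparameters are illustrative only.

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    recycling = True,              # AlphaFold2-style recycling (Jumper et al.)
    train_max_recycle_steps = 4    # upper bound for the randomly sampled steps during training
)
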
@@ -2082,7 +2097,7 @@ class TransformerWrapper(Module):

         # whether can do cached kv decoding

-        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

     def init_(self):
@@ -2105,6 +2120,7 @@ class TransformerWrapper(Module):
         return_attn = False,
         mems = None,
         mem_masks = None,
+        recycle_steps = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
@@ -2210,11 +2226,37 @@ class TransformerWrapper(Module):
         if exists(mem_every):
             x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

+        # handle maybe shifting of memories
+
         if self.shift_mem_down and exists(mems):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]

-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        # attention layers
+
+        if not self.recycling:
+            # regular
+
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        else:
+            # recycling
+
+            recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
+            assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'
+
+            for i in range(recycle_steps):
+                first_step = i == 0
+                last_step = i == (recycle_steps - 1)
+
+                context = nullcontext if last_step else torch.no_grad
+
+                with context():
+                    maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
+
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        x = attended

         # handle memories post-attention

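During training the forward hunk above samples a recycle step count in [1, train_max_recycle_steps] and only lets gradients flow through the final pass; at inference `recycle_steps` must be passed explicitly, and the earlier `can_cache_kv` hunk disables cached kv decoding whenever recycling is on. A usage sketch reusing the hypothetical `model` from the previous example:

import torch

ids = torch.randint(0, 20000, (1, 1024))

# training: the number of recycling passes is sampled internally
model.train()
logits = model(ids)

# inference: the number of recycling passes must be given explicitly
model.eval()
with torch.no_grad():
    logits = model(ids, recycle_steps = 3)
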
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.35.2
+Version: 1.36.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=mpA9hriHqCXLckdlVo8sxzXT6sjxwsY6AaKoP-Rpw3c,80631
+x_transformers/x_transformers.py,sha256=iib15Squ9VE7tLpb4Z4_Hq_hi7dZhPNR_xPtC9BzMrE,82321
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.35.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.35.2.dist-info/METADATA,sha256=4UuWPhkRRayYadZ8kwaHyqEGhHurnWJGRbPTzDMdEZo,661
-x_transformers-1.35.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.35.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.35.2.dist-info/RECORD,,
+x_transformers-1.36.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.36.0.dist-info/METADATA,sha256=YKcnT5T0UkZxwpP72cPfx9RN0SVoBYy0e6Xo581YCE0,661
+x_transformers-1.36.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.36.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.36.0.dist-info/RECORD,,