x-transformers 1.35.3__py3-none-any.whl → 1.36.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
--- x_transformers/x_transformers.py (1.35.3)
+++ x_transformers/x_transformers.py (1.36.0)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import math
-from random import random
+from random import random, randrange
 from packaging import version
 
 import torch
@@ -12,6 +12,7 @@ from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import List, Dict, Tuple, Callable
 
@@ -1987,7 +1988,9 @@ class TransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        recycling = False, # from Jumper et al. - Alphafold2
+        train_max_recycle_steps = 4, # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
@@ -2044,6 +2047,13 @@ class TransformerWrapper(Module):
 
         assert at_most_one_of(average_pool_embed, use_cls_token)
 
+        # maybe recycling
+
+        self.recycling = recycling
+        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+
+        self.train_max_recycle_steps = train_max_recycle_steps
+
         # classic cls token from the bert days
 
         self.cls_token = None
@@ -2087,7 +2097,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
@@ -2110,6 +2120,7 @@
         return_attn = False,
         mems = None,
         mem_masks = None,
+        recycle_steps = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
@@ -2215,11 +2226,37 @@
         if exists(mem_every):
             x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
+        # handle maybe shifting of memories
+
         if self.shift_mem_down and exists(mems):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        # attention layers
+
+        if not self.recycling:
+            # regular
+
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        else:
+            # recycling
+
+            recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
+            assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'
+
+            for i in range(recycle_steps):
+                first_step = i == 0
+                last_step = i == (recycle_steps - 1)
+
+                context = nullcontext if last_step else torch.no_grad
+
+                with context():
+                    maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
+
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        x = attended
+
         # handle memories post-attention
 
--- x_transformers-1.35.3.dist-info/METADATA
+++ x_transformers-1.36.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.35.3
+Version: 1.36.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.35.3.dist-info/RECORD
+++ x_transformers-1.36.0.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=ma5_LbZf5UvfKYJUJcqceUdFG8THFVzER9ZrDXKVV7Y,80780
+x_transformers/x_transformers.py,sha256=iib15Squ9VE7tLpb4Z4_Hq_hi7dZhPNR_xPtC9BzMrE,82321
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.35.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.35.3.dist-info/METADATA,sha256=YEiRJvu5g17ZVT3saNBhrmpNeRLqPXyN0cBdajt3psM,661
-x_transformers-1.35.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-x_transformers-1.35.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.35.3.dist-info/RECORD,,
+x_transformers-1.36.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.36.0.dist-info/METADATA,sha256=YKcnT5T0UkZxwpP72cPfx9RN0SVoBYy0e6Xo581YCE0,661
+x_transformers-1.36.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.36.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.36.0.dist-info/RECORD,,
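
Functionally, 1.36.0 adds an optional recycling scheme to TransformerWrapper, following the recycling idea from AlphaFold2 (Jumper et al.): the hidden states from the previous pass are projected and added back onto the input embeddings for several forward passes, with gradients flowing only through the final pass. During training the number of passes is sampled uniformly from 1 to train_max_recycle_steps; at inference, recycle_steps must be passed explicitly, and cached kv decoding is disabled while recycling is enabled. Below is a minimal usage sketch; the model hyperparameters (num_tokens, max_seq_len, dim, depth, heads) are illustrative values, not taken from this diff.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    recycling = True,               # new in 1.36.0: AlphaFold2-style recycling
    train_max_recycle_steps = 4,    # training samples 1..4 recycle passes uniformly
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

ids = torch.randint(0, 20000, (2, 1024))

# training: the number of recycle passes is sampled internally via randrange
model.train()
logits = model(ids)

# inference: recycle_steps must be supplied explicitly when recycling is enabled
model.eval()
with torch.no_grad():
    logits = model(ids, recycle_steps = 3)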