x-transformers 1.35.3__py3-none-any.whl → 1.36.0__py3-none-any.whl
This diff compares publicly available package versions as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- x_transformers/x_transformers.py +41 -4
- {x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/METADATA +1 -1
- {x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/RECORD +6 -6
- {x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import math
-from random import random
+from random import random, randrange
 from packaging import version
 
 import torch
@@ -12,6 +12,7 @@ from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import List, Dict, Tuple, Callable
 
@@ -1987,7 +1988,9 @@ class TransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        recycling = False,               # from Jumper et al. - Alphafold2
+        train_max_recycle_steps = 4,     # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
+        emb_frac_gradient = 1.,          # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
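The two new constructor arguments are all that is needed to switch recycling on. A minimal construction sketch, assuming the existing TransformerWrapper/Decoder API of this package (all sizes below are illustrative, not prescriptive):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,                # illustrative vocab size
    max_seq_len = 1024,                # illustrative context length
    recycling = True,                  # enable AlphaFold2-style recycling
    train_max_recycle_steps = 4,       # training samples 1..4 recycle steps per forward
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)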
@@ -2044,6 +2047,13 @@ class TransformerWrapper(Module):
 
         assert at_most_one_of(average_pool_embed, use_cls_token)
 
+        # maybe recycling
+
+        self.recycling = recycling
+        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+
+        self.train_max_recycle_steps = train_max_recycle_steps
+
         # classic cls token from the bert days
 
         self.cls_token = None
@@ -2087,7 +2097,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
@@ -2110,6 +2120,7 @@ class TransformerWrapper(Module):
         return_attn = False,
         mems = None,
         mem_masks = None,
+        recycle_steps = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
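As the next hunk shows, recycle_steps only defaults to a random draw while the module is training, so it must be passed explicitly at inference time. Continuing the construction sketch above, a hedged example of an eval-time call (dummy token ids):

model.eval()

ids = torch.randint(0, 20000, (1, 1024))

with torch.no_grad():
    logits = model(ids, recycle_steps = 3)   # must be provided explicitly when not training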
@@ -2215,11 +2226,37 @@
         if exists(mem_every):
             x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
+        # handle maybe shifting of memories
+
         if self.shift_mem_down and exists(mems):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
-        attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        # attention layers
+
+        if not self.recycling:
+            # regular
+
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        else:
+            # recycling
+
+            recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
+            assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'
+
+            for i in range(recycle_steps):
+                first_step = i == 0
+                last_step = i == (recycle_steps - 1)
+
+                context = nullcontext if last_step else torch.no_grad
+
+                with context():
+                    maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
+
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        x = attended
 
         # handle memories post-attention
 
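The loop above follows the AlphaFold2 recycling recipe: run the attention trunk several times, feed a detached, linearly projected copy of the previous output back into the input, and let gradients flow only through the final pass (torch.no_grad wraps every earlier pass). A stripped-down sketch of that pattern in plain PyTorch, independent of this library (trunk, recycle_proj and recycled_forward are placeholder names, not x-transformers API):

import torch
from torch import nn
from contextlib import nullcontext

def recycled_forward(trunk: nn.Module, recycle_proj: nn.Module, x: torch.Tensor, steps: int) -> torch.Tensor:
    # run the trunk `steps` times, reinjecting a detached projection of the
    # previous output; only the last iteration builds a gradient graph
    out = None

    for i in range(steps):
        last_step = i == (steps - 1)
        ctx = nullcontext if last_step else torch.no_grad

        with ctx():
            recycled = recycle_proj(out.detach()) if out is not None else 0.
            out = trunk(x + recycled)

    return out

# toy usage
trunk = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
recycle_proj = nn.Linear(512, 512, bias = False)
out = recycled_forward(trunk, recycle_proj, torch.randn(2, 16, 512), steps = 3)

Detaching before the projection keeps the backward pass the size of a single trunk evaluation no matter how many recycle steps were sampled, which is presumably also why can_cache_kv is switched off when recycling is enabled.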
{x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=iib15Squ9VE7tLpb4Z4_Hq_hi7dZhPNR_xPtC9BzMrE,82321
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.36.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.36.0.dist-info/METADATA,sha256=YKcnT5T0UkZxwpP72cPfx9RN0SVoBYy0e6Xo581YCE0,661
+x_transformers-1.36.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.36.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.36.0.dist-info/RECORD,,
{x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/LICENSE
File without changes
{x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/WHEEL
File without changes
{x_transformers-1.35.3.dist-info → x_transformers-1.36.0.dist-info}/top_level.txt
File without changes