x-transformers 1.35.2__py3-none-any.whl → 1.36.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- x_transformers/x_transformers.py +47 -5
- {x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/METADATA +1 -1
- {x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/RECORD +6 -6
- {x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import math
-from random import random
+from random import random, randrange
 from packaging import version
 
 import torch
@@ -12,6 +12,7 @@ from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import List, Dict, Tuple, Callable
 
@@ -920,6 +921,7 @@ class Attention(Module):
         kv_heads = None,
         shared_kv = False,
         value_dim_head = None,
+        dim_out = None,
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
@@ -1057,7 +1059,11 @@
         # attention on attention
 
         self.attn_on_attn = on_attn
-        self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
+
+        # output dimension by default same as input, but can be overridden
+
+        dim_out = default(dim_out, dim)
+        self.to_out = nn.Sequential(nn.Linear(out_dim, dim_out * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim_out, bias = False)
 
         # whether to rotate positions into values, for absolute positions in addition to relative
 
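The new `dim_out` keyword lets the attention output be projected to a width other than the model dimension. Below is a minimal usage sketch, not part of the diff: the kwargs other than `dim_out` rely on the module's existing defaults, and the return convention of `forward` (output tensor first, then intermediates) is assumed from the rest of the library.

import torch
from x_transformers.x_transformers import Attention

# hypothetical example values; only `dim_out` is new in 1.36.0
attn = Attention(dim = 512, heads = 8, dim_out = 256)

x = torch.randn(2, 128, 512)
out, *_ = attn(x)   # assumes forward returns the attended output first

assert out.shape == (2, 128, 256)   # projected to dim_out instead of dim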
@@ -1982,7 +1988,9 @@ class TransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         l2norm_embed = False,
-        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
+        recycling = False, # from Jumper et al. - Alphafold2
+        train_max_recycle_steps = 4, # saw a benefit for language modeling up to 3 recycling steps, so let's default this to 4
+        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
@@ -2039,6 +2047,13 @@ class TransformerWrapper(Module):
 
         assert at_most_one_of(average_pool_embed, use_cls_token)
 
+        # maybe recycling
+
+        self.recycling = recycling
+        self.recycled_proj = nn.Linear(dim, dim, bias = False) if recycling else None
+
+        self.train_max_recycle_steps = train_max_recycle_steps
+
         # classic cls token from the bert days
 
         self.cls_token = None
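A minimal sketch of enabling the new recycling options at construction time (not from the diff; the surrounding `TransformerWrapper` / `Decoder` kwargs follow the library's README and the specific values are assumptions):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8),
    recycling = True,               # new in 1.36.0: AlphaFold2-style recycling
    train_max_recycle_steps = 4     # during training, the step count is sampled from 1..4
)

The `recycled_proj` registered in this hunk is a plain dim -> dim linear; the forward hunk further down adds its output back onto the embedded input on every pass after the first.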
@@ -2082,7 +2097,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
@@ -2105,6 +2120,7 @@ class TransformerWrapper(Module):
         return_attn = False,
         mems = None,
         mem_masks = None,
+        recycle_steps = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
@@ -2210,11 +2226,37 @@
         if exists(mem_every):
             x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
+        # handle maybe shifting of memories
+
         if self.shift_mem_down and exists(mems):
             mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
             mems = [*mems_r, *mems_l]
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+        # attention layers
+
+        if not self.recycling:
+            # regular
+
+            attended, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        else:
+            # recycling
+
+            recycle_steps = default(recycle_steps, (randrange(self.train_max_recycle_steps) + 1) if self.training else None)
+            assert exists(recycle_steps) and recycle_steps > 0, '`recycle_steps` must be provided on forward if recycling is turned on and not training'
+
+            for i in range(recycle_steps):
+                first_step = i == 0
+                last_step = i == (recycle_steps - 1)
+
+                context = nullcontext if last_step else torch.no_grad
+
+                with context():
+                    maybe_recycled = self.recycled_proj(attended.detach()) if not first_step else 0.
+
+                    attended, intermediates = self.attn_layers(x + maybe_recycled, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
+
+        x = attended
 
         # handle memories post-attention
 
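Continuing the construction sketch above: per this hunk, training samples the recycle step count internally (via `randrange`, up to `train_max_recycle_steps`), all passes except the last run under `torch.no_grad` and feed a detached, projected hidden state back in, so gradients flow through a single pass regardless of the step count. Outside of training, the assertion requires `recycle_steps` to be passed explicitly. A hypothetical usage sketch:

x = torch.randint(0, 20000, (2, 1024))   # token ids; 1024 matches max_seq_len above

# training: recycle step count is sampled internally; only the final pass carries gradients
logits = model(x)

# inference: recycling is still active, so the step count must be given explicitly
model.eval()
with torch.no_grad():
    logits = model(x, recycle_steps = 3)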
{x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=iib15Squ9VE7tLpb4Z4_Hq_hi7dZhPNR_xPtC9BzMrE,82321
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.36.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.36.0.dist-info/METADATA,sha256=YKcnT5T0UkZxwpP72cPfx9RN0SVoBYy0e6Xo581YCE0,661
+x_transformers-1.36.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.36.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.36.0.dist-info/RECORD,,
{x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/LICENSE
File without changes
{x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/WHEEL
File without changes
{x_transformers-1.35.2.dist-info → x_transformers-1.36.0.dist-info}/top_level.txt
File without changes