x-transformers 1.32.2__py3-none-any.whl → 1.32.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +2 -0
- x_transformers/multi_input.py +292 -0
- x_transformers/x_transformers.py +13 -1
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.5.dist-info}/METADATA +1 -1
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.5.dist-info}/RECORD +8 -7
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.5.dist-info}/top_level.txt +0 -0
x_transformers/multi_input.py
CHANGED
@@ -0,0 +1,292 @@
|
|
1
|
+
import math
from typing import Dict

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleDict, ModuleList
import torch.nn.functional as F

from einops import pack, repeat, unpack, rearrange

from x_transformers.x_transformers import (
    AttentionLayers,
    Decoder,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    LayerIntermediates,
    LayerNorm,
    always,
    pad_at_dim,
    is_empty,
    calc_z_loss
)
|
20
|
+
|
21
|
+
# helper functions
|
22
|
+
|
23
|
+
def exists(val):
    """Return True when *val* is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
|
25
|
+
|
26
|
+
def default(val, d):
    """Return *val* unless it is ``None``, otherwise fall back to *d*.

    When the fallback *d* is callable it is invoked (lazily, only on the
    fallback path) and its result is returned instead of *d* itself.
    """
    if val is None:
        return d() if callable(d) else d
    return val
|
30
|
+
|
31
|
+
|
32
|
+
class MultiInputTransformerWrapper(Module):
    """Transformer wrapper whose input is a dict of several parallel token streams.

    Every named stream ``name`` is embedded by its own ``nn.Embedding`` (of size
    ``num_tokens[name]``). The embeddings are summed together with the positional
    embedding into one sequence, which is then run through ``attn_layers``.

    Output heads:
        * ``return_only_embed = True``: no head is built; forward returns embeddings.
        * ``tie_embedding = True``: each stream's logits are computed against that
          stream's own embedding matrix and returned as a dict keyed by stream name.
        * ``logits_dim`` a dict (the default, since it falls back to ``num_tokens``):
          one linear head per named stream; logits are returned as a dict.
        * ``logits_dim`` an int: a single linear head; logits are a tensor.
        * ``num_output_heads > 1``: a tuple with one of the above results per head.
    """

    def __init__(
        self,
        *,
        num_tokens: Dict[str, int] | None = None,
        max_seq_len,
        attn_layers: AttentionLayers,
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        return_only_embed = False,
        num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
        attn_z_loss_weight = 1e-4, # NOTE: not read here; forward takes its own `attn_z_loss_weight`
    ):
        """
        Args:
            num_tokens: mapping from input name to vocabulary size; must be non-empty.
            max_seq_len: maximum sequence length for absolute positional embeddings
                (0 disables them).
            attn_layers: the attention stack; its ``dim`` sets the model width.
            emb_dim: embedding width, defaults to ``attn_layers.dim`` (projected if different).
            logits_dim: int for a single shared head, or a dict name -> size for
                per-input heads. Defaults to ``num_tokens``.
            tie_embedding: reuse each input's embedding matrix as its output head;
                requires ``emb_dim == attn_layers.dim``.
            return_only_embed: build no output head at all.
            num_output_heads: number of independent output heads (must be > 0).
        """
        super().__init__()

        # `None` instead of a shared mutable `dict()` default; an empty mapping would
        # make forward unusable (it requires at least one input matching an embedding)
        num_tokens = default(num_tokens, dict)
        assert len(num_tokens) > 0, '`num_tokens` must map at least one input name to its vocabulary size'

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

        if no_abs_pos_emb:
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)

        # one embedding table per named input, registered under the key f'{name}_embed'

        self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()})

        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

        self.emb_frac_gradient = emb_frac_gradient

        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers

        assert num_output_heads > 0

        # output head(s). `logits_dim` defaults to the `num_tokens` dict, which yields one
        # projection per named input (previously this passed a dict to nn.Linear and crashed)

        logits_dim = default(logits_dim, num_tokens)

        self.has_multiple_heads = False

        if return_only_embed:
            self.to_logits = None
        elif tie_embedding:
            # the attention output (width `dim`) is multiplied by embedding matrices of width `emb_dim`
            assert emb_dim == dim, 'tied embeddings require `emb_dim` to equal `attn_layers.dim`'
            self.to_logits = self._tied_logits
        elif num_output_heads > 1:
            self.has_multiple_heads = True
            self.to_logits = ModuleList([self._make_head(dim, logits_dim) for _ in range(num_output_heads)])
        else:
            self.to_logits = self._make_head(dim, logits_dim)

        # memory tokens (like [cls]) from Memory Transformers paper

        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

        # whether can do cached kv decoding

        self.can_cache_kv = self.num_memory_tokens == 0
        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

    @staticmethod
    def _make_head(dim, logits_dim):
        """Build one output head: per-input projections for a dict `logits_dim`, else one projection."""
        if isinstance(logits_dim, dict):
            # keys get a `_logits` suffix, mirroring the `_embed` suffix used for `self.embeds`
            return ModuleDict({f'{name}_logits': nn.Linear(dim, one_logits_dim, bias = False) for name, one_logits_dim in logits_dim.items()})
        return nn.Linear(dim, logits_dim, bias = False)

    def _tied_logits(self, x):
        """Return a dict of logits per input name, using each input's embedding matrix as the head."""
        return {key.removesuffix('_embed'): x @ embed.weight.t() for key, embed in self.embeds.items()}

    @staticmethod
    def _apply_head(head, x):
        """Run one output head on `x`; ModuleDict heads return a dict keyed by input name."""
        if isinstance(head, ModuleDict):
            return {key.removesuffix('_logits'): fn(x) for key, fn in head.items()}
        return head(x)

    def forward(
        self,
        x: Dict[str, Tensor],
        return_embeddings = False,
        return_logits_and_embeddings = False,
        return_intermediates = False,
        mask = None,
        return_mems = False,
        return_attn = False,
        mems = None,
        mem_masks = None,
        pos = None,
        prepend_embeds = None,
        prepend_mask = None,
        sum_embeds = None,
        return_attn_z_loss = False,
        attn_z_loss_weight = 1e-4,
        seq_start_pos = None,
        cache: LayerIntermediates | None = None,
        **kwargs
    ):
        """Embed the named inputs, run the attention layers and project to logits.

        Args:
            x: mapping from input name to token ids. Every name must have an embedding.
                Batch and length are read from the first entry, which must be 2-D.
            return_embeddings: return the final hidden states instead of logits.
            return_logits_and_embeddings: return ``(logits, embeddings)``.

        Returns:
            ``out`` (logits or embeddings, see the class docstring for the logits form).
            Alternatively ``(out, new_mems)``, ``(out, intermediates)`` or
            ``(out, attn_maps)``, depending on the ``return_*`` flags.
        """
        assert not is_empty(x), 'at least one named input is required'
        first_input = next(iter(x.values()))

        b, n = first_input.shape
        device = first_input.device
        num_mems = self.num_memory_tokens
        has_memory_tokens = num_mems > 0
        emb_frac_gradient = self.emb_frac_gradient

        return_embeddings = return_embeddings | (not exists(self.to_logits))

        # token embedding - sum of one embedding per named input

        assert len(x) == len(self.embeds), f'expected {len(self.embeds)} named inputs, received {len(x)}'

        token_emb = 0.

        for name, embed_id in x.items():
            embed_key = f'{name}_embed'
            assert embed_key in self.embeds, f'no embedding registered for input `{name}`'
            token_emb = token_emb + self.embeds[embed_key](embed_id)

        # absolute positional embedding

        external_pos_emb = exists(pos) and pos.dtype != torch.long
        pos_emb = self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos

        token_emb = token_emb + pos_emb

        # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

        if exists(sum_embeds):
            token_emb = token_emb + sum_embeds

        # set back to `x`

        x = token_emb

        # post embedding norm, purportedly leads to greater stabilization

        x = self.post_emb_norm(x)

        # whether to append embeds, as in PaLI, for image embeddings

        if exists(prepend_embeds):
            prepend_seq, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if exists(prepend_mask) or exists(mask):
                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

        if emb_frac_gradient < 1:
            assert emb_frac_gradient > 0
            x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

        # embedding dropout

        x = self.emb_dropout(x)

        x = self.project_emb(x)

        if has_memory_tokens:
            # actual sequence length (including any prepended positions) before memory
            # tokens / interspersion padding are added; restored after attention
            mem_seq = x.shape[-2]
            mem_every = self.memory_tokens_interspersed_every

            if exists(mem_every):
                assert mem_every > 0
                assert isinstance(self.attn_layers, Decoder), 'only for decoder'
                next_seq_len = math.ceil(mem_seq / mem_every) * mem_every

                x = pad_at_dim(x, (0, next_seq_len - mem_seq), dim = -2, value = 0.)
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

            mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
            x, mem_packed_shape = pack((mem, x), 'b * d')

            # auto-handle masking after appending memory tokens
            if not exists(mem_every) and exists(mask):
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

        if self.shift_mem_down and exists(mems):
            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
            mems = [*mems_r, *mems_l]

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

        # handle memories post-attention

        if has_memory_tokens:
            if exists(mem_every):
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

            mem, x = unpack(x, mem_packed_shape, 'b * d')

            intermediates.memory_tokens = mem

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

            # strip interspersion padding only; slicing to `n` would also drop
            # real positions whenever embeddings were prepended
            x = x[:, :mem_seq]

        # projecting to logits

        compute_logits = return_logits_and_embeddings or not return_embeddings

        if compute_logits:
            assert exists(self.to_logits), 'logits requested, but the wrapper was built with `return_only_embed = True`'

            if self.has_multiple_heads:
                logits = tuple(self._apply_head(head, x) for head in self.to_logits)
            else:
                logits = self._apply_head(self.to_logits, x)

        # different returns

        if return_logits_and_embeddings:
            out = (logits, x)
        elif return_embeddings:
            out = x
        else:
            out = logits

        # aux loss

        if return_attn_z_loss:
            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
            return_intermediates = True

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
            # NOTE(review): with max_mem_len == 0 the slice `-0:` keeps the whole sequence
            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

            if not return_intermediates:
                return out, new_mems

            intermediates.mems = new_mems

        if return_intermediates:
            return out, intermediates

        if return_attn:
            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
            return out, attn_maps

        return out
|
x_transformers/x_transformers.py
CHANGED
@@ -51,6 +51,9 @@ def cast_tuple(val, depth):
|
|
51
51
|
def divisible_by(num, den):
    # `num` is divisible by `den` exactly when the remainder of the division is zero
    remainder = num % den
    return remainder == 0
|
53
53
|
|
54
|
+
def is_empty(x):
    """Return True when the sized container *x* holds no elements."""
    return not len(x)
|
56
|
+
|
54
57
|
def maybe(fn):
|
55
58
|
@wraps(fn)
|
56
59
|
def inner(x, *args, **kwargs):
|
@@ -1899,6 +1902,7 @@ class TransformerWrapper(Module):
|
|
1899
1902
|
memory_tokens_interspersed_every = None,
|
1900
1903
|
tie_embedding = False,
|
1901
1904
|
logits_dim = None,
|
1905
|
+
return_only_embed = False,
|
1902
1906
|
num_output_heads = 1,
|
1903
1907
|
use_abs_pos_emb = True,
|
1904
1908
|
scaled_sinu_pos_emb = False,
|
@@ -1948,13 +1952,17 @@ class TransformerWrapper(Module):
|
|
1948
1952
|
|
1949
1953
|
self.init_()
|
1950
1954
|
|
1955
|
+
assert num_output_heads > 0
|
1956
|
+
|
1951
1957
|
# output head, usually to logits of num_tokens
|
1952
1958
|
|
1953
1959
|
logits_dim = default(logits_dim, num_tokens)
|
1954
1960
|
|
1955
1961
|
self.has_multiple_heads = False
|
1956
1962
|
|
1957
|
-
if
|
1963
|
+
if return_only_embed:
|
1964
|
+
self.to_logits = None
|
1965
|
+
elif tie_embedding:
|
1958
1966
|
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
|
1959
1967
|
elif num_output_heads > 1:
|
1960
1968
|
self.has_multiple_heads = True
|
@@ -2008,7 +2016,9 @@ class TransformerWrapper(Module):
|
|
2008
2016
|
**kwargs
|
2009
2017
|
):
|
2010
2018
|
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
|
2019
|
+
|
2011
2020
|
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
|
2021
|
+
return_embeddings = return_embeddings | (not exists(self.to_logits))
|
2012
2022
|
|
2013
2023
|
# absolute positional embedding
|
2014
2024
|
|
@@ -2018,6 +2028,8 @@ class TransformerWrapper(Module):
|
|
2018
2028
|
|
2019
2029
|
# add additional embeddings
|
2020
2030
|
|
2031
|
+
assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
|
2032
|
+
|
2021
2033
|
if exists(self.embeds):
|
2022
2034
|
assert len(embed_ids) == len(self.embeds)
|
2023
2035
|
|
@@ -1,14 +1,15 @@
|
|
1
|
-
x_transformers/__init__.py,sha256
|
1
|
+
x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
|
2
2
|
x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
|
+
x_transformers/multi_input.py,sha256=S4hJffAhN7Ubbd92_9YmZ0VOzdu1u-1e2IgWZxYz9BU,9758
|
6
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
|
8
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.32.
|
11
|
-
x_transformers-1.32.
|
12
|
-
x_transformers-1.32.
|
13
|
-
x_transformers-1.32.
|
14
|
-
x_transformers-1.32.
|
11
|
+
x_transformers-1.32.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.32.5.dist-info/METADATA,sha256=HSujd9BjY5lAIY5QxDqf4JDecEOXrvWQoe2zjpRaGG0,661
|
13
|
+
x_transformers-1.32.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
14
|
+
x_transformers-1.32.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.32.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|