x-transformers 1.32.4__py3-none-any.whl → 1.32.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +2 -0
- x_transformers/multi_input.py +292 -0
- x_transformers/x_transformers.py +3 -0
- {x_transformers-1.32.4.dist-info → x_transformers-1.32.5.dist-info}/METADATA +1 -1
- {x_transformers-1.32.4.dist-info → x_transformers-1.32.5.dist-info}/RECORD +8 -7
- {x_transformers-1.32.4.dist-info → x_transformers-1.32.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.4.dist-info → x_transformers-1.32.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.4.dist-info → x_transformers-1.32.5.dist-info}/top_level.txt +0 -0
x_transformers/multi_input.py
CHANGED
@@ -0,0 +1,292 @@
|
|
1
|
+
import math
from typing import Dict

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleDict, ModuleList
import torch.nn.functional as F

from einops import rearrange, pack, repeat, unpack

from x_transformers.x_transformers import (
    AttentionLayers,
    Decoder,
    ScaledSinusoidalEmbedding,
    AbsolutePositionalEmbedding,
    LayerIntermediates,
    LayerNorm,
    always,
    pad_at_dim,
    is_empty,
    calc_z_loss
)
|
20
|
+
|
21
|
+
# helper functions
|
22
|
+
|
23
|
+
def exists(val):
    """Report whether `val` carries a value, i.e. is anything other than None."""
    return not (val is None)
|
25
|
+
|
26
|
+
def default(val, d):
    """
    Return `val` unless it is None, otherwise fall back to `d`.

    A callable `d` is treated as a zero-argument factory and is only invoked
    when the fallback is actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
|
30
|
+
|
31
|
+
|
32
|
+
class MultiInputTransformerWrapper(Module):
    """
    Transformer wrapper over several parallel, named token streams.

    Each named stream in `num_tokens` gets its own `nn.Embedding`. In `forward`,
    the per-stream embeddings are summed together with the positional embedding
    into one sequence, which is then run through `attn_layers`.

    Output heads, selected at construction time:
    - `return_only_embed = True`: no head; `forward` returns embeddings.
    - `tie_embedding = True`: one head per input stream, using that stream's
      embedding matrix transposed; logits are a dict `{name: logits}`.
    - `logits_dim` left as None (it then defaults to `num_tokens`) or given as a
      dict: one linear head per name; logits are a dict `{name: logits}`.
    - `logits_dim` given as an int: one shared linear head, or a tuple of
      `num_output_heads` heads when `num_output_heads > 1`.
    """

    def __init__(
        self,
        *,
        num_tokens: Dict[str, int] | None = None,
        max_seq_len,
        attn_layers: AttentionLayers,
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        return_only_embed = False,
        num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
        attn_z_loss_weight = 1e-4,
    ):
        super().__init__()

        # fresh dict per instance instead of a shared mutable default
        num_tokens = default(num_tokens, dict)

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)

        if no_abs_pos_emb:
            self.pos_emb = always(0)
        elif scaled_sinu_pos_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
        else:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)

        # one embedding table per named input stream, stored under '<name>_embed'

        self.input_names = tuple(num_tokens.keys())
        self.embeds = ModuleDict({f'{name}_embed': nn.Embedding(one_num_tokens, emb_dim) for name, one_num_tokens in num_tokens.items()})

        # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290

        self.emb_frac_gradient = emb_frac_gradient

        self.post_emb_norm = LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers

        # weight used by `forward` for the attention z-loss when none is passed there
        self.attn_z_loss_weight = attn_z_loss_weight

        assert num_output_heads > 0

        # output head(s) - by default one head per input stream, sized by its vocabulary

        logits_dim = default(logits_dim, num_tokens)

        self.has_multiple_heads = False
        self.has_named_heads = False
        self.logits_names = ()

        if return_only_embed:
            self.to_logits = None
        elif tie_embedding:
            # each stream's logits come from its own embedding matrix, which has width emb_dim,
            # while the attention output has width dim
            assert emb_dim == dim, 'tie_embedding requires emb_dim to equal the attention layers dimension'
            self.to_logits = lambda t: {name: t @ self.embeds[f'{name}_embed'].weight.t() for name in self.input_names}
        elif isinstance(logits_dim, dict):
            assert num_output_heads == 1, 'num_output_heads > 1 requires an integer logits_dim'
            self.has_named_heads = True
            self.logits_names = tuple(logits_dim.keys())
            self.to_logits = ModuleDict({f'{name}_logits': nn.Linear(dim, one_logits_dim, bias = False) for name, one_logits_dim in logits_dim.items()})
        elif num_output_heads > 1:
            self.has_multiple_heads = True
            self.to_logits = ModuleList([nn.Linear(dim, logits_dim, bias = False) for _ in range(num_output_heads)])
        else:
            self.to_logits = nn.Linear(dim, logits_dim, bias = False)

        # memory tokens (like [cls]) from Memory Transformers paper

        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))

        self.memory_tokens_interspersed_every = memory_tokens_interspersed_every

        # whether can do cached kv decoding

        self.can_cache_kv = self.num_memory_tokens == 0
        self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

    def _project_to_logits(self, x):
        """Apply the configured output head(s) to the final hidden states `x`."""
        if self.has_named_heads:
            return {name: self.to_logits[f'{name}_logits'](x) for name in self.logits_names}

        if self.has_multiple_heads:
            return tuple(fn(x) for fn in self.to_logits)

        return self.to_logits(x)

    def forward(
        self,
        x: Dict[str, Tensor],
        return_embeddings = False,
        return_logits_and_embeddings = False,
        return_intermediates = False,
        mask = None,
        return_mems = False,
        return_attn = False,
        mems = None,
        mem_masks = None,
        pos = None,
        prepend_embeds = None,
        prepend_mask = None,
        sum_embeds = None,
        return_attn_z_loss = False,
        attn_z_loss_weight = None,
        seq_start_pos = None,
        cache: LayerIntermediates | None = None,
        **kwargs
    ):
        """
        Embed the named inputs in `x`, run the attention layers and return the
        output selected by the `return_*` flags.

        Every key of `x` must match a name given in `num_tokens` at construction,
        and all streams must be present. `attn_z_loss_weight` falls back to the
        value given at construction when left as None.
        """
        assert not is_empty(x), 'at least one input stream is required'
        first_input = next(iter(x.values()))

        b, n, device = *first_input.shape, first_input.device
        num_mems, has_memory_tokens = self.num_memory_tokens, self.num_memory_tokens > 0
        emb_frac_gradient = self.emb_frac_gradient
        attn_z_loss_weight = default(attn_z_loss_weight, self.attn_z_loss_weight)

        return_embeddings = return_embeddings | (not exists(self.to_logits))

        # token embedding - sum of the embeddings of every named input stream

        assert len(x) == len(self.embeds), f'expected {len(self.embeds)} input streams, got {len(x)}'

        token_emb = 0.

        for name, embed_id in x.items():
            embed_key = f'{name}_embed'

            assert embed_key in self.embeds, f'no embedding registered for input {name!r}'
            token_emb = token_emb + self.embeds[embed_key](embed_id)

        # absolute positional embedding - a non-long `pos` is used directly as the positional embedding

        external_pos_emb = exists(pos) and pos.dtype != torch.long
        pos_emb = pos if external_pos_emb else self.pos_emb(first_input, pos = pos, seq_start_pos = seq_start_pos)

        token_emb = token_emb + pos_emb

        # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training

        if exists(sum_embeds):
            token_emb = token_emb + sum_embeds

        x = token_emb

        # post embedding norm, purportedly leads to greater stabilization

        x = self.post_emb_norm(x)

        # whether to append embeds, as in PaLI, for image embeddings

        if exists(prepend_embeds):
            prepend_seq, prepend_dim = prepend_embeds.shape[1:]
            assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if exists(prepend_mask) or exists(mask):
                mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
                prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

                mask = torch.cat((prepend_mask, mask), dim = -1)

        # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

        if emb_frac_gradient < 1:
            assert emb_frac_gradient > 0
            x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)

        # embedding dropout

        x = self.emb_dropout(x)

        x = self.project_emb(x)

        # length of the sequence entering the attention layers - includes any prepended embeddings,
        # so it can differ from `n`, the length of the raw token inputs

        seq_len = x.shape[-2]

        if has_memory_tokens:
            mem_every = self.memory_tokens_interspersed_every

            if exists(mem_every):
                assert mem_every > 0
                assert isinstance(self.attn_layers, Decoder), 'only for decoder'
                next_seq_len = math.ceil(seq_len / mem_every) * mem_every

                x = pad_at_dim(x, (0, next_seq_len - seq_len), dim = -2, value = 0.)
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)

            mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
            x, mem_packed_shape = pack((mem, x), 'b * d')

            # auto-handle masking after appending memory tokens
            if not exists(mem_every) and exists(mask):
                mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

        if self.shift_mem_down and exists(mems):
            mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
            mems = [*mems_r, *mems_l]

        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)

        # handle memories post-attention

        if has_memory_tokens:
            if exists(mem_every):
                x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))

            mem, x = unpack(x, mem_packed_shape, 'b * d')

            intermediates.memory_tokens = mem

            if exists(mem_every):
                x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

            # drop the padding added for interspersed memory tokens, keeping any prepended positions
            x = x[:, :seq_len]

        # projecting to logits and selecting the output

        if return_logits_and_embeddings:
            assert exists(self.to_logits), 'cannot return logits when the wrapper was built with return_only_embed = True'
            out = (self._project_to_logits(x), x)
        elif return_embeddings:
            out = x
        else:
            out = self._project_to_logits(x)

        # aux loss

        if return_attn_z_loss:
            pre_softmax_attns = [t.pre_softmax_attn for t in intermediates.attn_intermediates]
            intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
            return_intermediates = True

        if return_mems:
            hiddens = intermediates.hiddens
            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
            new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

            if not return_intermediates:
                return out, new_mems

            intermediates.mems = new_mems

        if return_intermediates:
            return out, intermediates

        if return_attn:
            attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates]
            return out, attn_maps

        return out
|
x_transformers-1.32.5.dist-info/RECORD
CHANGED
@@ -1,14 +1,15 @@
|
|
1
|
-
x_transformers/__init__.py,sha256
|
1
|
+
x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
|
2
2
|
x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
|
+
x_transformers/multi_input.py,sha256=S4hJffAhN7Ubbd92_9YmZ0VOzdu1u-1e2IgWZxYz9BU,9758
|
6
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
|
8
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.32.
|
11
|
-
x_transformers-1.32.
|
12
|
-
x_transformers-1.32.
|
13
|
-
x_transformers-1.32.
|
14
|
-
x_transformers-1.32.
|
11
|
+
x_transformers-1.32.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.32.5.dist-info/METADATA,sha256=HSujd9BjY5lAIY5QxDqf4JDecEOXrvWQoe2zjpRaGG0,661
|
13
|
+
x_transformers-1.32.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
14
|
+
x_transformers-1.32.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.32.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|