x-transformers 2.4.9__py3-none-any.whl → 2.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +177 -1
- {x_transformers-2.4.9.dist-info → x_transformers-2.4.10.dist-info}/METADATA +1 -1
- {x_transformers-2.4.9.dist-info → x_transformers-2.4.10.dist-info}/RECORD +5 -5
- {x_transformers-2.4.9.dist-info → x_transformers-2.4.10.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.9.dist-info → x_transformers-2.4.10.dist-info}/licenses/LICENSE +0 -0
@@ -8,7 +8,7 @@ from torch import nn, Tensor
|
|
8
8
|
from torch.nn import Module
|
9
9
|
import torch.nn.functional as F
|
10
10
|
|
11
|
-
from einops import rearrange, pack, unpack
|
11
|
+
from einops import rearrange, repeat, pack, unpack
|
12
12
|
|
13
13
|
def exists(val):
    """Report whether ``val`` was supplied, i.e. whether it is anything other than ``None``."""
    return not (val is None)
|
@@ -34,6 +34,21 @@ def eval_decorator(fn):
|
|
34
34
|
return out
|
35
35
|
return inner
|
36
36
|
|
37
|
+
# gumbel topk
|
38
|
+
|
39
|
+
def log(t, eps = 1e-20):
    """Natural log of ``t`` after flooring every element at ``eps``, so zeros map to a finite value."""
    floored = t.clamp(min = eps)
    return floored.log()
|
41
|
+
|
42
|
+
def gumbel_noise(t):
    """Sample Gumbel noise shaped like ``t`` via the inverse-CDF transform ``-log(-log(u))``.

    ``u`` is drawn with ``torch.rand_like(t)``; both logarithms go through the clamped
    ``log`` helper so the edges of the uniform sample do not produce infinities.
    """
    uniform = torch.rand_like(t)
    inner = -log(uniform)
    return -log(inner)
|
44
|
+
|
45
|
+
# function for modifying all the cached key / values
|
46
|
+
|
47
|
+
def modify_cached_kv(cache, fn):
    """Rewrite, in place, the cached key / value tensors of every attention intermediate in ``cache``.

    Only intermediates whose ``layer_type`` equals ``'a'`` are touched; for each of them
    ``cached_kv`` is rebound to a new list holding ``fn`` applied to every cached tensor.
    """
    attention_layers = (inter for inter in cache.attn_intermediates if inter.layer_type == 'a')
    for layer in attention_layers:
        layer.cached_kv = list(map(fn, layer.cached_kv))
|
51
|
+
|
37
52
|
# for variable lengthed prefixes
|
38
53
|
|
39
54
|
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
|
@@ -161,6 +176,167 @@ class AutoregressiveWrapper(Module):
|
|
161
176
|
self.add_continuous_pred_head = net.add_continuous_pred_head
|
162
177
|
self.next_embed_loss_weight = next_embed_loss_weight
|
163
178
|
|
179
|
+
@torch.no_grad()
@eval_decorator
def beam_search(
    self,
    prompts,
    seq_len,
    beams = 4,
    eos_token = None,
    temperature = 1.,
    stochastic = False,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    filter_kwargs: dict = dict(),
    cache_kv = True,
    **kwargs
):
    """
    Generate ``seq_len`` tokens per prompt with beam search and return the first beam.

    Each step, every live beam proposes ``beams`` continuations. These come from a top-k
    on the logits, or from a gumbel top-k over filtered, temperature-scaled logits when
    ``stochastic`` is set and ``temperature`` is non-zero. Candidates are scored by their
    accumulated log-probability. Once more than ``beams`` candidates exist per batch
    element, they are sorted by that score and only the best ``beams`` are kept.

    Args:
        prompts: token ids with the sequence on the last dimension. Any leading dimensions
            are packed into one batch dimension and restored on return.
        seq_len: number of tokens to generate.
        beams: beam width.
        eos_token: not supported yet (asserted to be ``None``).
        temperature: divides the logits before adding gumbel noise in stochastic mode.
            ``0.`` disables the stochastic path.
        stochastic: use gumbel top-k instead of plain top-k to propose candidates.
        prompt_lens: optional per-prompt lengths. Prompts are right-aligned and the
            resulting start positions are forwarded to the network.
        filter_logits_fn: callable, or key into ``FILTER_LOGITS_FN``, applied in stochastic mode.
        restrict_to_max_seq_len: only feed the last ``self.max_seq_len`` tokens to the network.
        filter_kwargs: keyword arguments forwarded to ``filter_logits_fn``.
        cache_kv: reuse key / values across steps when ``self.net.can_cache_kv``.
        **kwargs: forwarded to ``self.net``.

    Returns:
        The generated tokens (prompt excluded) of the beam in position 0 for each batch
        element, reshaped back to the leading dimensions of ``prompts``. After any pruning
        step, position 0 holds the highest accumulated score.
    """
    assert not exists(eos_token), 'eos token not supported yet'

    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, packed_shape = pack([prompts], '* n')

    batch, orig_seq_len = prompts.shape

    # handle filter logits fn given as string

    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"

        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # handle variable lengthed prompts (prefixes)

    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = orig_seq_len - prompt_lens

    # output from which sampled tokens appended to

    out = prompts

    # kv caches

    cache = None

    should_cache = cache_kv and self.net.can_cache_kv

    # accumulated log-probability for each row of `out`

    scores = torch.zeros((batch,), device = device)

    batch_arange = torch.arange(batch, device = device)

    # sampling up to seq_len

    for i in range(seq_len):
        is_first = i == 0

        # by default the whole sequence is fed to the network
        # (previously `x` was left unbound when restrict_to_max_seq_len was False)

        x = out

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if should_cache:
            cache = new_cache

        logits = logits[:, -1]

        # to add to the scores

        log_probs = logits.log_softmax(dim = -1)

        # maybe filter by top_k, top_p (nucleus) for stochastic beam search

        if stochastic and not greedy:
            logits = filter_logits_fn(logits, **filter_kwargs)
            logits = (logits / temperature) + gumbel_noise(logits)

        # (gumbel) topk

        samples = logits.topk(beams, dim = -1).indices

        # get the scores for keeping track of beams

        next_scores = log_probs.gather(-1, samples)

        # expand beam times

        scores = repeat(scores, 'b -> b beams', beams = beams)
        scores = scores + next_scores

        out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
        samples = rearrange(samples, 'b beams -> (b beams) 1')

        if should_cache:
            modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))

        # the per-prompt start positions must follow the rows of `out`, which grow from
        # `batch` to `batch * beams` after the first step and keep that size afterwards

        if is_first and exists(seq_start_pos):
            seq_start_pos = repeat(seq_start_pos, 'b -> (b beams)', beams = beams)

        # concat sample

        out = torch.cat((out, samples), dim = -1)

        # sort by score and excise out the beams

        scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
        curr_num_beams = scores.shape[-1]

        if curr_num_beams > beams:
            scores, sort_indices = scores.sort(dim = -1, descending = True)

            scores = scores[:, :beams]
            top_beams_indices = sort_indices[:, :beams]

            # each batch element owns `curr_num_beams` consecutive rows of `out` (and of
            # the cache), so offset by that, not by the beam width

            top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices

            flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')

            out = out[flattened_beam_indices]

            # there is no cache to reindex when caching is disabled or unsupported

            if should_cache:
                modify_cached_kv(cache, lambda t: t[flattened_beam_indices])

        scores = rearrange(scores, 'b beams -> (b beams)')

        if not exists(eos_token):
            continue

        is_eos_tokens = (out == eos_token)

        if is_eos_tokens.any(dim = -1).all():
            break

    # unreachable while the eos assert above is in place; kept for future eos support

    if exists(eos_token):
        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    # select out the top beam

    out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)

    out = out[:, 0, orig_seq_len:]

    out, = unpack(out, packed_shape, '* n')

    return out
|
339
|
+
|
164
340
|
@torch.no_grad()
|
165
341
|
@eval_decorator
|
166
342
|
def generate(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
|
2
2
|
x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
|
3
|
-
x_transformers/autoregressive_wrapper.py,sha256=
|
3
|
+
x_transformers/autoregressive_wrapper.py,sha256=oTGHT5k52gsF5MXvH1WxkJ1Zq3HEmQ4laNM96RMdeiY,17594
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
5
|
x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
@@ -12,7 +12,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
|
|
12
12
|
x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
|
13
13
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
14
14
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
15
|
-
x_transformers-2.4.
|
16
|
-
x_transformers-2.4.
|
17
|
-
x_transformers-2.4.
|
18
|
-
x_transformers-2.4.
|
15
|
+
x_transformers-2.4.10.dist-info/METADATA,sha256=gbkmhEOCOTsxuZ7zwdya-GF2FkdqDjd9ZlVUcBuRHZ8,90224
|
16
|
+
x_transformers-2.4.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
x_transformers-2.4.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
18
|
+
x_transformers-2.4.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|