x-transformers 2.4.10__py3-none-any.whl → 2.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +14 -8
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/METADATA +1 -1
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/RECORD +5 -5
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/WHEEL +0 -0
- {x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/licenses/LICENSE +0 -0
x_transformers/autoregressive_wrapper.py

@@ -183,11 +183,12 @@ class AutoregressiveWrapper(Module):
         prompts,
         seq_len,
         beams = 4,
+        return_beams_and_scores = False,
         eos_token = None,
         temperature = 1.,
         stochastic = False,
         prompt_lens: Tensor | None = None,
-        filter_logits_fn: str | Callable =
+        filter_logits_fn: str | Callable = identity,
         restrict_to_max_seq_len = True,
         filter_kwargs: dict = dict(),
         cache_kv = True,
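The only interface change in this hunk is the new return_beams_and_scores flag (the filter_logits_fn default now reads identity; the removed 2.4.10 line is cut off in this view). A minimal usage sketch, assuming this signature belongs to AutoregressiveWrapper.beam_search and using a toy model with arbitrary sizes purely for illustration:

    import torch
    from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

    # hypothetical toy wrapper, only to show the call shape
    model = AutoregressiveWrapper(TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
    ))

    prompts = torch.randint(0, 256, (2, 8))   # (batch, prompt length)

    # default behavior: one best sequence per prompt
    best = model.beam_search(prompts, seq_len = 16, beams = 4)

    # with the new flag, every kept beam and its score are returned as well
    beams_out, scores = model.beam_search(
        prompts,
        seq_len = 16,
        beams = 4,
        return_beams_and_scores = True
    )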
@@ -285,7 +286,7 @@ class AutoregressiveWrapper(Module):
                 out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
                 samples = rearrange(samples, 'b beams -> (b beams) 1')

-                if should_cache:
+                if should_cache and is_first:
                     modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))

             # concat sample
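The added is_first guard limits the beam expansion of the cached key/values to the first decoding step. A small standalone sketch of what that repeat does, with illustrative shapes (the (batch, seq, dim) layout here is an assumption, not taken from the package):

    import torch
    from einops import repeat

    beams = 4
    cached = torch.randn(2, 8, 64)   # e.g. (batch, seq, dim) for one cached tensor

    expanded = repeat(cached, 'b ... -> (b beams) ...', beams = beams)
    print(expanded.shape)            # torch.Size([8, 8, 64])

    # repeating an already expanded cache again would grow the leading axis to 32,
    # which is presumably why the expansion is now gated on is_first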
@@ -303,14 +304,13 @@ class AutoregressiveWrapper(Module):

            scores = scores[:, :beams]
            top_beams_indices = sort_indices[:, :beams]
-
+
+            top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices

            flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')

            out = out[flattened_beam_indices]

-            modify_cached_kv(cache, lambda t: t[flattened_beam_indices])
-
            scores = rearrange(scores, 'b beams -> (b beams)')

            if not exists(eos_token):
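The added line offsets each sample's top beam indices (values in 0 .. curr_num_beams - 1 within that sample) into the flattened (batch * curr_num_beams) axis, so the gather that follows picks beams from the correct sample. A worked example with made-up values for batch, curr_num_beams and sort_indices:

    import torch
    from einops import rearrange

    batch, curr_num_beams, beams = 2, 8, 4
    batch_arange = torch.arange(batch)

    # hypothetical per-sample beam ranks after sorting by score
    sort_indices = torch.tensor([[3, 0, 6, 1, 2, 4, 5, 7],
                                 [7, 2, 5, 4, 0, 1, 3, 6]])

    top_beams_indices = sort_indices[:, :beams]
    # tensor([[3, 0, 6, 1],
    #         [7, 2, 5, 4]])

    # shift the second sample's indices by curr_num_beams so they address the flat axis
    top_beams_indices = curr_num_beams * batch_arange[:, None] + top_beams_indices
    # tensor([[ 3,  0,  6,  1],
    #         [15, 10, 13, 12]])

    flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
    # tensor([ 3,  0,  6,  1, 15, 10, 13, 12])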
@@ -331,11 +331,17 @@ class AutoregressiveWrapper(Module):

        out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)

-        out = out[
+        out = out[..., orig_seq_len:]

-        out, = unpack(out, packed_shape, '* n')
+        out, = unpack(out, packed_shape, '* beams n') # prompt may have no batch dimension

-
+        if not return_beams_and_scores:
+            return out[..., 0, :]
+
+        scores = rearrange(scores, '(b beams) -> beams b', b = batch)
+        out = rearrange(out, 'b beams n -> beams b n')
+
+        return out, scores

    @torch.no_grad()
    @eval_decorator
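With return_beams_and_scores left at False the method still returns a single sequence per prompt (beam index 0 after sorting); with the flag set, both tensors come back beam-major. Continuing the earlier toy sketch, and again assuming the method is AutoregressiveWrapper.beam_search and that the prompts carry a batch dimension:

    # beams_out: (beams, batch, generated tokens); scores: (beams, batch)
    beams_out, scores = model.beam_search(prompts, seq_len = 16, beams = 4, return_beams_and_scores = True)

    top_beam = beams_out[0]   # same sequences the default (flag = False) path returns
    top_score = scores[0]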
{x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/RECORD

@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -12,7 +12,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
-x_transformers-2.4.
+x_transformers-2.4.12.dist-info/METADATA,sha256=0o0pofz1ZRwIeFs-e-D9fISslemepdDCQne8FHHeccc,90224
+x_transformers-2.4.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.4.12.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.4.12.dist-info/RECORD,,
{x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/WHEEL: file without changes
{x_transformers-2.4.10.dist-info → x_transformers-2.4.12.dist-info}/licenses/LICENSE: file without changes