x-transformers 2.5.3-py3-none-any.whl → 2.5.5-py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported public registry, and is provided for informational purposes only.
- x_transformers/autoregressive_wrapper.py +17 -4
- {x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/METADATA +1 -1
- {x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/RECORD +5 -5
- {x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/autoregressive_wrapper.py

```diff
@@ -4,9 +4,10 @@ from math import ceil, log
 from typing import Tuple, Callable

 import torch
-from torch import nn, Tensor
+from torch import nn, tensor, Tensor
 from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence

 from einops import rearrange, repeat, pack, unpack

```
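The new `pad_sequence` import supports the list-of-prompts handling added below. For context, `torch.nn.utils.rnn.pad_sequence` collates a list of variable-length tensors into a single batch tensor; a quick illustration with arbitrary token ids:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# two "prompts" of different lengths
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

# batch_first = True yields shape (batch, max_len);
# shorter sequences are right-padded with 0 by default
padded = pad_sequence([a, b], batch_first = True)
# tensor([[1, 2, 3],
#         [4, 5, 0]])
```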
```diff
@@ -347,7 +348,7 @@ class AutoregressiveWrapper(Module):
     @eval_decorator
     def generate(
         self,
-        prompts,
+        prompts: list[Tensor] | Tensor,
         seq_len,
         eos_token = None,
         temperature = 1.,
```
```diff
@@ -363,11 +364,23 @@ class AutoregressiveWrapper(Module):
         cache_kv = True,
         **kwargs
     ):
-        max_seq_len, greedy
+        max_seq_len, greedy = self.max_seq_len, temperature == 0.
+
+        # handle prompts given as list of variable lengthed token ids
+
+        if isinstance(prompts, list):
+            assert len(prompts) > 0, 'prompts cannot be empty list'
+            assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
+
+            prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
+
+            prompts = pad_sequence(prompts, batch_first = True)
+
+        # pack maybe no batch

         prompts, ps = pack([prompts], '* n')

-        b, t = prompts.shape
+        b, t, device = *prompts.shape, prompts.device

         # handle filter logits fn given as string

```
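Taken together, these hunks let `generate` accept a plain Python list of 1-D token tensors of differing lengths: the wrapper derives `prompt_lens` automatically and pads the batch before packing. A minimal usage sketch (the model hyperparameters are illustrative, not from this diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapped = AutoregressiveWrapper(model)

# variable-length prompts passed as a list; as of 2.5.5 there is no need
# to pad them or pass prompt_lens manually
prompts = [
    torch.randint(0, 256, (5,)),
    torch.randint(0, 256, (9,)),
]

generated = wrapped.generate(prompts, seq_len = 32)
```

Note that passing a list together with an explicit `prompt_lens` now trips the new assertion, since the lengths are derived from the list itself.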
{x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/RECORD

```diff
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=BsGO9xfVYkvynqbU1__tu_S_cxl7gss0YwnkhIa2baY,18401
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
```
```diff
@@ -12,7 +12,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=fW-AoomNCw4n2JFbZN9rZV3lKQvz_Tl6L4txUvac_9o,119993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
-x_transformers-2.5.
+x_transformers-2.5.5.dist-info/METADATA,sha256=Igay1acyeLzF_vDvB9BW7NWuAy_ck7G2rhITKre3Lew,90223
+x_transformers-2.5.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.5.dist-info/RECORD,,
```

{x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/WHEEL: file without changes

{x_transformers-2.5.3.dist-info → x_transformers-2.5.5.dist-info}/licenses/LICENSE: file without changes