x-transformers 2.5.3.tar.gz → 2.5.4.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.5.3 → x_transformers-2.5.4}/PKG-INFO +1 -1
- {x_transformers-2.5.3 → x_transformers-2.5.4}/pyproject.toml +1 -1
- {x_transformers-2.5.3 → x_transformers-2.5.4}/tests/test_x_transformers.py +29 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/autoregressive_wrapper.py +18 -4
- {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/FUNDING.yml +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/.gitignore +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/LICENSE +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/README.md +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/data/README.md +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/data/enwik8.gz +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/all-attention.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/attention-on-attention.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/deepnorm.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/fcm.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/ffglu.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/flash-attention.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/gate_values.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/gating.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/macaron-1.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/macaron-2.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/memory-transformer.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/normformer.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/pia.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/resi_dual.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/residual_attn.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/rezero.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/rotary.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich-2.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/sandwich_norm.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/scalenorm.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/talking-heads.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/topk-attention.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/images/xval.png +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_belief_state.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_copy.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_enwik8.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_length_extrapolate.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/train_parity.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/__init__.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/attend.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/continuous.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/dpo.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/xval.py +0 -0
{x_transformers-2.5.3 → x_transformers-2.5.4}/tests/test_x_transformers.py

```diff
@@ -1181,3 +1181,32 @@ def test_attn_pooler(
     out = model(x)
 
     assert out.shape == (2, num_pooled_tokens, 77)
+
+def test_prompts_given_as_list_tensor():
+    from x_transformers import AutoregressiveWrapper
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 512,
+            depth = 12,
+            heads = 8
+        )
+    )
+
+    wrapped = AutoregressiveWrapper(model)
+
+    seq = torch.randint(0, 20000, (3, 1024))
+
+    loss = wrapped(seq)
+    loss.backward()
+
+    sampled = wrapped.generate([
+        torch.randint(0, 20000, (3,)),
+        torch.randint(0, 20000, (5,)),
+        torch.randint(0, 20000, (2,)),
+        torch.randint(0, 20000, (7,)),
+    ], 256)
+
+    assert sampled.shape == (4, 256)
```
{x_transformers-2.5.3 → x_transformers-2.5.4}/x_transformers/autoregressive_wrapper.py

```diff
@@ -4,9 +4,10 @@ from math import ceil, log
 from typing import Tuple, Callable
 
 import torch
-from torch import nn, Tensor
+from torch import nn, tensor, Tensor
 from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange, repeat, pack, unpack
 
```
```diff
@@ -347,7 +348,7 @@ class AutoregressiveWrapper(Module):
     @eval_decorator
     def generate(
         self,
-        prompts,
+        prompts: list[Tensor] | Tensor,
         seq_len,
         eos_token = None,
         temperature = 1.,
```
```diff
@@ -363,11 +364,23 @@ class AutoregressiveWrapper(Module):
         cache_kv = True,
         **kwargs
     ):
-        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+        max_seq_len, greedy = self.max_seq_len, temperature == 0.
+
+        # handle prompts given as list of variable lengthed token ids
+
+        if isinstance(prompts, list):
+            assert len(prompts) > 0, 'prompts cannot be empty list'
+            assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
+
+            prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
+
+            prompts = pad_sequence(prompts, batch_first = True)
+
+        # pack maybe no batch
 
         prompts, ps = pack([prompts], '* n')
 
-        b, t = prompts.shape
+        b, t, device = *prompts.shape, prompts.device
 
         # handle filter logits fn given as string
 
```
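The core of the change is the list-handling branch above: variable-length prompt tensors are padded into a single batch with `pad_sequence`, and `prompt_lens` is derived automatically instead of being passed in. A minimal standalone sketch of that padding step, using hypothetical toy prompts rather than the wrapper's internal state:

```python
import torch
from torch import tensor
from torch.nn.utils.rnn import pad_sequence

# three toy prompts of lengths 3, 5 and 2 (illustrative values only)
prompts = [
    torch.randint(0, 20000, (3,)),
    torch.randint(0, 20000, (5,)),
    torch.randint(0, 20000, (2,)),
]

# lengths are derived per sample, as in the new branch of generate()
prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)  # tensor([3, 5, 2])

# pad_sequence right-pads every prompt to the longest one (padding value 0 by default)
padded = pad_sequence(prompts, batch_first = True)

assert padded.shape == (3, int(prompt_lens.max().item()))
```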
```diff
@@ -380,6 +393,7 @@ class AutoregressiveWrapper(Module):
 
         seq_start_pos = None
         if exists(prompt_lens):
+            print('prompt lens')
             prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
             seq_start_pos = t - prompt_lens
 
```
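Once `prompt_lens` is known, the existing `align_right` path shown above shifts each prompt flush against the right edge of the padded batch, so `seq_start_pos = t - prompt_lens` marks where each prompt begins. A rough sketch of what that right alignment achieves, written with plain tensor ops rather than the library's `align_right` helper:

```python
import torch

# right-padded batch of two prompts (lengths 3 and 5), pad id 0 -- toy values
padded = torch.tensor([
    [5, 7, 9, 0, 0],
    [1, 2, 3, 4, 6],
])
prompt_lens = torch.tensor([3, 5])
t = padded.shape[-1]

# move each prompt's tokens to the right edge (loop kept simple for clarity)
aligned = torch.zeros_like(padded)
for i, n in enumerate(prompt_lens.tolist()):
    aligned[i, t - n:] = padded[i, :n]

# generation for sample i then continues from position t - prompt_lens[i]
seq_start_pos = t - prompt_lens  # tensor([2, 0])
```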
All other files listed above are unchanged between 2.5.3 and 2.5.4.