x-transformers 2.5.2-py3-none-any.whl → 2.5.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/autoregressive_wrapper.py
+++ x_transformers/autoregressive_wrapper.py
@@ -4,9 +4,10 @@ from math import ceil, log
 from typing import Tuple, Callable
 
 import torch
-from torch import nn, Tensor
+from torch import nn, tensor, Tensor
 from torch.nn import Module
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 
 from einops import rearrange, repeat, pack, unpack
 
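The two new imports are the whole mechanism behind the list-of-prompts support added to `generate` below: `tensor` builds the per-prompt length vector, and `pad_sequence` right-pads variable-length 1-D token tensors into a single batch. A minimal standalone sketch of that combination (the token values here are illustrative):

```python
import torch
from torch import tensor
from torch.nn.utils.rnn import pad_sequence

# three variable-length "prompts" of token ids
prompts = [tensor([5, 2, 9]), tensor([7]), tensor([1, 3, 8, 4])]

# per-prompt lengths, as generate() now derives them internally
prompt_lens = tensor([t.shape[0] for t in prompts])  # tensor([3, 1, 4])

# pad on the right into a (batch, max_len) matrix; padding_value defaults to 0
padded = pad_sequence(prompts, batch_first = True)   # shape (3, 4)
```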
--- x_transformers/autoregressive_wrapper.py
+++ x_transformers/autoregressive_wrapper.py
@@ -347,7 +348,7 @@ class AutoregressiveWrapper(Module):
     @eval_decorator
     def generate(
         self,
-        prompts,
+        prompts: list[Tensor] | Tensor,
         seq_len,
         eos_token = None,
         temperature = 1.,
@@ -363,11 +364,23 @@ class AutoregressiveWrapper(Module):
         cache_kv = True,
         **kwargs
     ):
-        max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+        max_seq_len, greedy = self.max_seq_len, temperature == 0.
+
+        # handle prompts given as list of variable lengthed token ids
+
+        if isinstance(prompts, list):
+            assert len(prompts) > 0, 'prompts cannot be empty list'
+            assert not exists(prompt_lens), '`prompt_len` will be auto derived if prompts are passed in as list of Tensors'
+
+            prompt_lens = tensor([t.shape[0] for t in prompts], device = prompts[0].device)
+
+            prompts = pad_sequence(prompts, batch_first = True)
+
+        # pack maybe no batch
 
         prompts, ps = pack([prompts], '* n')
 
-        b, t = prompts.shape
+        b, t, device = *prompts.shape, prompts.device
 
         # handle filter logits fn given as string
 
@@ -380,6 +393,7 @@ class AutoregressiveWrapper(Module):
 
         seq_start_pos = None
         if exists(prompt_lens):
+            print('prompt lens')
             prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
             seq_start_pos = t - prompt_lens
 
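Taken together, these hunks let `generate` accept either a padded batch tensor or a plain Python list of variable-length token tensors: lengths are derived, the list is padded into a batch, and `align_right` shifts each prompt flush against the generation boundary. A usage sketch following the README-style `TransformerWrapper` setup (the hyperparameters and token values are illustrative, not from the diff):

```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
))

# prompts of different lengths, passed directly as a list of 1-D tensors;
# prompt_lens must NOT be passed - the wrapper now derives it itself
prompts = [
    torch.randint(0, 256, (5,)),
    torch.randint(0, 256, (9,)),
]

sampled = model.generate(prompts, seq_len = 32)  # (2, 32) tokens past each prompt
```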
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -2787,6 +2787,7 @@ class AttentionPool(Module):
         self.pooler = Attention(dim = dim, dim_context = dim_context, heads = heads, dim_head = dim_head, **attn_kwargs)
 
         self.add_residual = add_residual
+        self.squeeze_output = squeeze_output
 
     def forward(self, context, mask = None):
         batch = context.shape[0]
@@ -2798,6 +2799,9 @@ class AttentionPool(Module):
         if self.add_residual:
             pooled = pooled + queries
 
+        if self.squeeze_output:
+            pooled = rearrange(pooled, 'b 1 d -> b d')
+
         return pooled
 
 class ViTransformerWrapper(Module):
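`AttentionPool` gains an opt-in `squeeze_output` flag that drops the singleton query dimension, returning `(batch, dim)` instead of `(batch, 1, dim)`. A sketch of the difference, assuming `squeeze_output` is exposed as an `__init__` keyword alongside the arguments visible in the diff (the signature line itself falls outside the shown hunks, and the import path simply follows the file the class lives in):

```python
import torch
from x_transformers.x_transformers import AttentionPool

context = torch.randn(2, 64, 512)  # (batch, seq, dim)

pool = AttentionPool(dim = 512, dim_context = 512, heads = 8, dim_head = 64)
print(pool(context).shape)  # expected: torch.Size([2, 1, 512])

squeezing_pool = AttentionPool(dim = 512, dim_context = 512, heads = 8, dim_head = 64, squeeze_output = True)
print(squeezing_pool(context).shape)  # expected: torch.Size([2, 512])
```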
--- x_transformers-2.5.2.dist-info/METADATA
+++ x_transformers-2.5.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.5.2
+Version: 2.5.4
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.5.2.dist-info/RECORD
+++ x_transformers-2.5.4.dist-info/RECORD
@@ -1,6 +1,6 @@
 x_transformers/__init__.py,sha256=aVuhUU0572TJHW88BVc4yA2tla0Zb8l3NH7W4RZ1AEs,1005
 x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
-x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
+x_transformers/autoregressive_wrapper.py,sha256=wu7yJOf2XL5QI0vJtRTSsyOI-JpCuwE7YXGEumhtQYQ,18434
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
 x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=vmMrHP3hAQ9iAJlRN1pKmXOn7pD3mfh_ndtaR7LMPzU,119860
+x_transformers/x_transformers.py,sha256=fW-AoomNCw4n2JFbZN9rZV3lKQvz_Tl6L4txUvac_9o,119993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.5.2.dist-info/METADATA,sha256=yeferX_PJIv0Lxs36vZSV7Z2w9ol4udiUAON95hP_bY,90223
-x_transformers-2.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.5.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.5.2.dist-info/RECORD,,
+x_transformers-2.5.4.dist-info/METADATA,sha256=sDkhsAQnYbflFctIdoHdtrXV6doTSO_jNLcnFus0CEY,90223
+x_transformers-2.5.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.5.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.5.4.dist-info/RECORD,,