x-transformers 1.22.11__py3-none-any.whl → 1.22.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,8 +28,7 @@ def eval_decorator(fn):
28
28
 
29
29
  # for variable lengthed prefixes
30
30
 
31
- def align(t, lens, pad_id = 0, left = False, right = False):
32
- assert left ^ right
31
+ def align_right(t, lens, pad_id = 0):
33
32
  batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
34
33
 
35
34
  assert lens.ndim == 1 and lens.shape[0] == batch
@@ -41,14 +40,9 @@ def align(t, lens, pad_id = 0, left = False, right = False):
41
40
  batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
42
41
  prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
43
42
 
44
- if left:
45
- padding = (0, max_pad_len)
46
- offset = pad_lens
47
- elif right:
48
- padding = (max_pad_len, 0)
49
- offset = max_pad_len - pad_lens
43
+ t = F.pad(t, (max_pad_len, 0), value = 0)
44
+ offset = max_pad_len - pad_lens
50
45
 
51
- t = F.pad(t, padding, value = 0)
52
46
  aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
53
47
  return aligned
54
48
 
@@ -157,7 +151,7 @@ class AutoregressiveWrapper(Module):
157
151
 
158
152
  seq_start_pos = None
159
153
  if exists(prompt_lens):
160
- prompts = align(prompts, prompt_lens, pad_id = self.pad_value, right = True)
154
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
161
155
  seq_start_pos = t - prompt_lens
162
156
 
163
157
  # output from which sampled tokens appended to
@@ -244,11 +238,6 @@ class AutoregressiveWrapper(Module):
244
238
  out = out.masked_fill(mask, self.pad_value)
245
239
  break
246
240
 
247
- # if variable lengthed, needs to realign
248
-
249
- if exists(prompt_lens):
250
- out = align(out, prompt_lens, pad_id = self.pad_value, left = True)
251
-
252
241
  out = out[:, t:]
253
242
 
254
243
  out, = unpack(out, ps, '* n')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.22.11
3
+ Version: 1.22.12
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,12 +1,12 @@
1
1
  x_transformers/__init__.py,sha256=FDb654rUx8FpXRd76B8q0diH8I7q-ZjTWEtEJ4UM21Y,701
2
2
  x_transformers/attend.py,sha256=xPa6RjnMDsc1jKliQdThETMTQeRX3ycmAlw5pgzLIf4,12605
3
- x_transformers/autoregressive_wrapper.py,sha256=hz3qp_Hmt4oR2MfVw-0ctFBLHW4Liu8XHlkd-rr8O48,8818
3
+ x_transformers/autoregressive_wrapper.py,sha256=uUnwXP2uZ4oJSN4EVXfWQormKWv8c6yzrE5tDZUjSag,8480
4
4
  x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
5
5
  x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
6
6
  x_transformers/x_transformers.py,sha256=mYveA7PqRUZg9-82ALFBpuhTfQirfCF5rxL6EUCdU5I,59075
7
7
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
8
- x_transformers-1.22.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
9
- x_transformers-1.22.11.dist-info/METADATA,sha256=-_D35IjVE9XNd4Tc2ixu7GVhAsliBTz3hfcL9GchDDA,662
10
- x_transformers-1.22.11.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
11
- x_transformers-1.22.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
12
- x_transformers-1.22.11.dist-info/RECORD,,
8
+ x_transformers-1.22.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
9
+ x_transformers-1.22.12.dist-info/METADATA,sha256=lissCQf2eUs5Oalsk0PRw0gwTiybipzr9P5YRybgZdQ,662
10
+ x_transformers-1.22.12.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
11
+ x_transformers-1.22.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
12
+ x_transformers-1.22.12.dist-info/RECORD,,