x-transformers 2.3.3__py3-none-any.whl → 2.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -128,6 +128,7 @@ class ContinuousTransformerWrapper(Module):
128
128
  mems = None,
129
129
  mem_masks = None,
130
130
  pos = None,
131
+ sum_embeds = None,
131
132
  prepend_embeds = None,
132
133
  prepend_mask = None,
133
134
  **kwargs
@@ -147,6 +148,9 @@ class ContinuousTransformerWrapper(Module):
147
148
  x = self.project_in(x)
148
149
  x = x + self.pos_emb(x, pos = pos)
149
150
 
151
+ if exists(sum_embeds):
152
+ x = x + sum_embeds
153
+
150
154
  x = self.post_emb_norm(x)
151
155
 
152
156
  # memory tokens
@@ -293,6 +297,20 @@ class ContinuousAutoregressiveWrapper(Module):
293
297
 
294
298
  assert 'prepend_embeds' not in kwargs
295
299
 
300
+ # lens
301
+
302
+ lens = kwargs.pop('lens', None)
303
+
304
+ if exists(lens):
305
+ assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
306
+ seq_len, device = inp.shape[1], inp.device
307
+ seq_arange = torch.arange(seq_len, device = device)
308
+ mask = einx.less('j, i -> i j', seq_arange, lens)
309
+
310
+ kwargs['mask'] = mask
311
+
312
+ # mask
313
+
296
314
  mask = kwargs.get('mask', None)
297
315
 
298
316
  if exists(mask) and mask.shape[1] == x.shape[1]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.3
3
+ Version: 2.3.5
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
2
2
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=663KriM8j3VQPe732kTklB4RPheS4Zj30sB8F9c8ySg,9016
5
+ x_transformers/continuous.py,sha256=bTxwCt_8RlT1-aR2F4R8YOhpjMF-TbpElRbbRiNd6M8,9512
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
14
- x_transformers-2.3.3.dist-info/METADATA,sha256=1Ly7M8Jc_tL2auBw7N8Q4wpCSsp8zq-Td7AHqHzkZmk,88686
15
- x_transformers-2.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.3.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.3.dist-info/RECORD,,
14
+ x_transformers-2.3.5.dist-info/METADATA,sha256=wPHqpSgc75F3npfdSNCzro1F6PBlVXabA0oarpvZMHI,88686
15
+ x_transformers-2.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.5.dist-info/RECORD,,