x-transformers 2.3.2__py3-none-any.whl → 2.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +26 -0
- {x_transformers-2.3.2.dist-info → x_transformers-2.3.4.dist-info}/METADATA +1 -1
- {x_transformers-2.3.2.dist-info → x_transformers-2.3.4.dist-info}/RECORD +5 -5
- {x_transformers-2.3.2.dist-info → x_transformers-2.3.4.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.2.dist-info → x_transformers-2.3.4.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -123,6 +123,7 @@ class ContinuousTransformerWrapper(Module):
|
|
123
123
|
return_intermediates = False,
|
124
124
|
return_mems = False,
|
125
125
|
mask = None,
|
126
|
+
lens = None,
|
126
127
|
return_attn = False,
|
127
128
|
mems = None,
|
128
129
|
mem_masks = None,
|
@@ -133,6 +134,16 @@ class ContinuousTransformerWrapper(Module):
|
|
133
134
|
):
|
134
135
|
batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
|
135
136
|
|
137
|
+
# maybe seq lengths passed in
|
138
|
+
|
139
|
+
if exists(lens):
|
140
|
+
assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
|
141
|
+
seq_arange = torch.arange(seq, device = device)
|
142
|
+
|
143
|
+
mask = einx.less('j, i -> i j', seq_arange, lens)
|
144
|
+
|
145
|
+
# project in + positional embedding
|
146
|
+
|
136
147
|
x = self.project_in(x)
|
137
148
|
x = x + self.pos_emb(x, pos = pos)
|
138
149
|
|
@@ -282,7 +293,22 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
282
293
|
|
283
294
|
assert 'prepend_embeds' not in kwargs
|
284
295
|
|
296
|
+
# lens
|
297
|
+
|
298
|
+
lens = kwargs.pop('lens', None)
|
299
|
+
|
300
|
+
if exists(lens):
|
301
|
+
assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
|
302
|
+
seq_len, device = inp.shape[1], inp.device
|
303
|
+
seq_arange = torch.arange(seq_len, device = device)
|
304
|
+
mask = einx.less('j, i -> i j', seq_arange, lens)
|
305
|
+
|
306
|
+
kwargs['mask'] = mask
|
307
|
+
|
308
|
+
# mask
|
309
|
+
|
285
310
|
mask = kwargs.get('mask', None)
|
311
|
+
|
286
312
|
if exists(mask) and mask.shape[1] == x.shape[1]:
|
287
313
|
mask = mask[:, :-1]
|
288
314
|
kwargs['mask'] = mask
|
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
|
|
2
2
|
x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=TwqPXZG14JGzfK3iK4bdQo1u0u2KKNLRcRjh54Jc0z8,9422
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
-
x_transformers-2.3.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
-
x_transformers-2.3.2.dist-info/RECORD,,
|
14
|
+
x_transformers-2.3.4.dist-info/METADATA,sha256=dRdw8JjntClC_behPfg-Upf0bGtuhn1UKklkXRrLCmU,88686
|
15
|
+
x_transformers-2.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.4.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|