x-transformers 2.3.15__py3-none-any.whl → 2.3.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +10 -2
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.16.dist-info}/METADATA +1 -1
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.16.dist-info}/RECORD +5 -5
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.16.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.16.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -130,6 +130,7 @@ class ContinuousTransformerWrapper(Module):
|
|
130
130
|
sum_embeds = None,
|
131
131
|
prepend_embeds = None,
|
132
132
|
prepend_mask = None,
|
133
|
+
seq_start_pos = None,
|
133
134
|
**kwargs
|
134
135
|
):
|
135
136
|
batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
|
@@ -145,7 +146,7 @@ class ContinuousTransformerWrapper(Module):
|
|
145
146
|
# project in + positional embedding
|
146
147
|
|
147
148
|
x = self.project_in(x)
|
148
|
-
x = x + self.pos_emb(x, pos = pos)
|
149
|
+
x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
|
149
150
|
|
150
151
|
if exists(sum_embeds):
|
151
152
|
x = x + sum_embeds
|
@@ -315,6 +316,8 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
315
316
|
|
316
317
|
# get target
|
317
318
|
|
319
|
+
seq_start_pos = None
|
320
|
+
|
318
321
|
if one_step_autoregress:
|
319
322
|
target = x[:, None, 1:]
|
320
323
|
else:
|
@@ -328,6 +331,8 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
328
331
|
|
329
332
|
target = x[batch_arange, target_indices] # rollout targets
|
330
333
|
|
334
|
+
seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
|
335
|
+
|
331
336
|
# assert inputs
|
332
337
|
|
333
338
|
assert 'prepend_embeds' not in kwargs
|
@@ -363,12 +368,15 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
363
368
|
|
364
369
|
# forward
|
365
370
|
|
366
|
-
out = self.net(inp, mask = step_mask, **kwargs)
|
371
|
+
out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
|
367
372
|
|
368
373
|
outputs.append(out)
|
369
374
|
|
370
375
|
inp = out
|
371
376
|
|
377
|
+
if not one_step_autoregress:
|
378
|
+
seq_start_pos.sub_(1)
|
379
|
+
|
372
380
|
# stack masks and predictions from rollouts
|
373
381
|
|
374
382
|
masks = stack(masks, dim = 1) if exists(mask) else None
|
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
|
|
2
2
|
x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.
|
16
|
-
x_transformers-2.3.
|
17
|
-
x_transformers-2.3.
|
14
|
+
x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
|
15
|
+
x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|