x-transformers 2.3.15__py3-none-any.whl → 2.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -130,6 +130,7 @@ class ContinuousTransformerWrapper(Module):
130
130
  sum_embeds = None,
131
131
  prepend_embeds = None,
132
132
  prepend_mask = None,
133
+ seq_start_pos = None,
133
134
  **kwargs
134
135
  ):
135
136
  batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
@@ -145,7 +146,7 @@ class ContinuousTransformerWrapper(Module):
145
146
  # project in + positional embedding
146
147
 
147
148
  x = self.project_in(x)
148
- x = x + self.pos_emb(x, pos = pos)
149
+ x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
149
150
 
150
151
  if exists(sum_embeds):
151
152
  x = x + sum_embeds
@@ -315,6 +316,8 @@ class ContinuousAutoregressiveWrapper(Module):
315
316
 
316
317
  # get target
317
318
 
319
+ seq_start_pos = None
320
+
318
321
  if one_step_autoregress:
319
322
  target = x[:, None, 1:]
320
323
  else:
@@ -328,6 +331,8 @@ class ContinuousAutoregressiveWrapper(Module):
328
331
 
329
332
  target = x[batch_arange, target_indices] # rollout targets
330
333
 
334
+ seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
335
+
331
336
  # assert inputs
332
337
 
333
338
  assert 'prepend_embeds' not in kwargs
@@ -363,12 +368,15 @@ class ContinuousAutoregressiveWrapper(Module):
363
368
 
364
369
  # forward
365
370
 
366
- out = self.net(inp, mask = step_mask, **kwargs)
371
+ out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
367
372
 
368
373
  outputs.append(out)
369
374
 
370
375
  inp = out
371
376
 
377
+ if not one_step_autoregress:
378
+ seq_start_pos.sub_(1)
379
+
372
380
  # stack masks and predictions from rollouts
373
381
 
374
382
  masks = stack(masks, dim = 1) if exists(mask) else None
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.15
3
+ Version: 2.3.16
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=5SZmi3Bd77aJAu50f4y1OwwruZd_3ZHptC8dtQmvvxM,11387
5
+ x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.15.dist-info/METADATA,sha256=Dv6r9tZhbF-_7q5yhgmo3pto-D6NVJnYPNnSBEBt73I,89897
15
- x_transformers-2.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.15.dist-info/RECORD,,
14
+ x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
15
+ x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.16.dist-info/RECORD,,