x-transformers 2.1.26__py3-none-any.whl → 2.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/belief_state_wrapper.py
@@ -9,7 +9,7 @@ from __future__ import annotations
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, tensor, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.autoregressive_wrapper import (
@@ -34,6 +34,25 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+# a custom flip that can handle variable lengths across batch
+
+def flip(x, dim = 1, lens = None):
+    if not exists(lens):
+        return x.flip(dim)
+
+    batch, seq_len, device = *x.shape[:2], x.device
+    seq = arange(seq_len, device = device)
+
+    mask = einx.less('j, i -> i j', seq, lens)
+    masked_seq = einx.where('i j, j,', mask, seq, -1)
+
+    flip_indices = masked_seq.argsort(dim = -1, descending = True)
+
+    if x.ndim == 3:
+        flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1])
+
+    return x.gather(dim, flip_indices)
+
 # wrappers
 
 class BeliefStateWrapper(Module):
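The new flip reverses only the valid prefix of each row and pushes padded positions to the end: positions at or beyond a row's length are masked to -1, so a descending argsort returns the valid indices in reverse order followed by the padding. A minimal sketch of the same idea in plain PyTorch (no einx; flip_varlen and the toy tensors below are illustrative only, not part of the package):

    import torch

    def flip_varlen(x, lens, dim = 1):
        # positions past each row's length are masked to -1, so a descending
        # argsort yields the valid positions reversed, then the padded ones
        seq_len = x.shape[1]
        seq = torch.arange(seq_len, device = x.device)

        mask = seq[None, :] < lens[:, None]                        # (batch, seq_len) validity mask
        masked_seq = torch.where(mask, seq, torch.full_like(seq, -1))

        flip_indices = masked_seq.argsort(dim = -1, descending = True)

        if x.ndim == 3:
            # broadcast the gather indices over the feature dimension
            flip_indices = flip_indices[..., None].expand(-1, -1, x.shape[-1])

        return x.gather(dim, flip_indices)

    x = torch.tensor([[1, 2, 3, 0, 0],
                      [4, 5, 6, 7, 8]])
    lens = torch.tensor([3, 5])

    print(flip_varlen(x, lens))
    # tensor([[3, 2, 1, 0, 0],
    #         [8, 7, 6, 5, 4]])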
@@ -230,19 +249,26 @@ class BeliefStateWrapper(Module):
     def forward(
         self,
         seq,
+        lens: Tensor | None = None, # Int['b']
         return_loss_only = False,
         loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
+        # handle variable length sequences
+
+        if exists(lens):
+            mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
+            seq_for_labels = torch.where(mask, seq, -1)
+
         # forward autoregressive
 
         forward_embeds = self.forward_decoder(seq, return_embeddings = True)
 
         # backward autoregressive
 
-        backward_seq = seq.flip(1)
+        backward_seq = flip(seq, lens = lens)
 
         suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)
 
@@ -252,7 +278,7 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
-        backward_embeds = backward_embeds.flip(1)
+        backward_embeds = flip(backward_embeds, lens = lens)
 
         # trick to reduce memory on backwards pass
 
@@ -294,7 +320,7 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), (bi - 1)
 
-        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+        forward_labels, backward_labels = seq_for_labels[:, labels_fi], seq_for_labels[:, labels_bi]
 
         labels = cat((forward_labels, backward_labels), dim = -1)
 
@@ -312,7 +338,8 @@ class BeliefStateWrapper(Module):
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
             labels,
-            reduction = 'none' if self.needs_loss_weight else 'mean'
+            reduction = 'none' if self.needs_loss_weight else 'mean',
+            ignore_index = -1
         )
 
         # maybe predict terminal
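When lens is passed, targets at padded positions are set to -1 in seq_for_labels, and ignore_index = -1 tells cross entropy to skip them. A small self-contained illustration of that mechanism with toy shapes and values (not the wrapper's actual call):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 4, 10)           # (batch, positions, num_tokens), toy values
    labels = torch.tensor([[5, 7, -1, -1],   # row of length 2: last two targets ignored
                           [1, 2,  3,  4]])  # row of full length

    loss = F.cross_entropy(
        logits.transpose(1, 2),              # cross_entropy expects (batch, classes, positions)
        labels,
        ignore_index = -1                    # padded targets contribute nothing to the loss
    )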
{x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.26
+Version: 2.1.28
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/RECORD
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=dnLlKe830rWjUbCKEuRg880KkXW2bVxd2vvdM4O5ZQU,11613
+x_transformers/belief_state_wrapper.py,sha256=nx7NdEZQ98Puz1RwAl7wThFJ_R8xLpUbwoqYjb6IF28,12508
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.26.dist-info/METADATA,sha256=vk30PeHoMDinssUDBEaUV3rT3NpATdJ7GDZSNzDfzhg,87875
-x_transformers-2.1.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.26.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.26.dist-info/RECORD,,
+x_transformers-2.1.28.dist-info/METADATA,sha256=9VQXzWtJjNhONmS9sSxM4DQrJZJok1TgkUN0q8eT-S0,87875
+x_transformers-2.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.28.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.28.dist-info/RECORD,,