x_transformers-2.1.25-py3-none-any.whl → x_transformers-2.1.27-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
x_transformers/belief_state_wrapper.py

@@ -9,7 +9,7 @@ from __future__ import annotations
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, tensor, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.autoregressive_wrapper import (
@@ -34,6 +34,25 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+# a custom flip that can handle variable lengths across batch
+
+def flip(x, dim = 1, lens = None):
+    if not exists(lens):
+        return x.flip(dim)
+
+    batch, seq_len, device = *x.shape[:2], x.device
+    seq = arange(seq_len, device = device)
+
+    mask = einx.less('j, i -> i j', seq, lens)
+    masked_seq = einx.where('i j, j,', mask, seq, -1)
+
+    flip_indices = masked_seq.argsort(dim = -1, descending = True)
+
+    if x.ndim == 3:
+        flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1])
+
+    return x.gather(dim, flip_indices)
+
 # wrappers
 
 class BeliefStateWrapper(Module):
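For reference, here is a minimal sketch of what this variable-length flip computes, rewritten in plain torch (without einx) so it stands alone; the tensors and lengths are made up for illustration. Only the first lens[i] tokens of each row are reversed, while pad positions stay at the tail:

    import torch

    x    = torch.tensor([[1, 2, 3, 0, 0],
                         [1, 2, 3, 4, 5]])
    lens = torch.tensor([3, 5])

    seq = torch.arange(x.shape[1])
    mask = seq[None, :] < lens[:, None]                    # (batch, seq) validity mask
    masked_seq = torch.where(mask, seq, torch.tensor(-1))  # pad positions sort last
    flip_indices = masked_seq.argsort(dim = -1, descending = True)

    out = x.gather(1, flip_indices)
    # out: tensor([[3, 2, 1, 0, 0],
    #              [5, 4, 3, 2, 1]])

A plain x.flip(1) would instead drag the first row's padding to the front, which is exactly the misalignment this helper avoids.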
@@ -230,19 +249,25 @@ class BeliefStateWrapper(Module):
     def forward(
         self,
         seq,
+        lens: Tensor | None = None, # Int['b']
         return_loss_only = False,
         loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
+        # handle variable length sequences
+
+        if exists(lens):
+            mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
+
         # forward autoregressive
 
         forward_embeds = self.forward_decoder(seq, return_embeddings = True)
 
         # backward autoregressive
 
-        backward_seq = seq.flip(1)
+        backward_seq = flip(seq, lens = lens)
 
         suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)
 
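A hypothetical usage sketch for the new lens argument; the wrapper construction follows the pattern in the project README, but the model sizes and the weight-tied backward decoder (left to its default) are assumptions made for brevity:

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.belief_state_wrapper import BeliefStateWrapper

    model = BeliefStateWrapper(
        forward_decoder = TransformerWrapper(
            num_tokens = 256,
            max_seq_len = 128,
            attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
        )
    )

    seq  = torch.randint(0, 256, (2, 17))
    lens = torch.tensor([11, 17])   # valid tokens per sample

    loss = model(seq, lens = lens)  # sequences shorter than seq_len are now handled
    loss.backward()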
@@ -252,7 +277,7 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
-        backward_embeds = backward_embeds.flip(1)
+        backward_embeds = flip(backward_embeds, lens = lens)
 
         # trick to reduce memory on backwards pass
@@ -320,7 +345,7 @@ class BeliefStateWrapper(Module):
         if exists(self.to_distance_logits):
             distance_logits = self.to_distance_logits(fb_embeds)
 
-            distance_labels = bi - fi
+            distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
             distance_labels = repeat(distance_labels, 'n -> b n', b = batch)
 
             pred_dist_loss = F.cross_entropy(
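The clamp guards against an out-of-range target: F.cross_entropy requires every label to be strictly less than the number of classes, so a forward/backward index gap that exceeds what the distance head can represent must be capped at the last bucket. A minimal illustration, with max_pred_distance made up for the example:

    import torch
    import torch.nn.functional as F

    max_pred_distance = 8
    logits = torch.randn(3, max_pred_distance)

    raw = torch.tensor([2, 7, 40])                 # 40 is out of range as a raw target
    labels = raw.clamp(max = max_pred_distance - 1)

    loss = F.cross_entropy(logits, labels)         # safe after clamping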
x_transformers-2.1.25.dist-info/METADATA → x_transformers-2.1.27.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.25
+Version: 2.1.27
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.25.dist-info/RECORD → x_transformers-2.1.27.dist-info/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=xPzhUZYm7qdHYp9fQ73HjwvWmEhve6-cEisvYK5serI,11571
+x_transformers/belief_state_wrapper.py,sha256=p9OZTX-xs8YAiPdDCpZLqCG8VGmF8OJU6Zriy3mTGfo,12399
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.25.dist-info/METADATA,sha256=h182tY5ffDUrwUr8VZYekkTeMcDlMD_Krw1XP2K0YWU,87875
-x_transformers-2.1.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.25.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.25.dist-info/RECORD,,
+x_transformers-2.1.27.dist-info/METADATA,sha256=mUUzC0KPtW-RFPxwHyY_xr72ZrfDjPjjhaCyH4erUTw,87875
+x_transformers-2.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.27.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.27.dist-info/RECORD,,