x-transformers 2.1.26__py3-none-any.whl → 2.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +32 -5
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/METADATA +1 -1
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/RECORD +5 -5
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.28.dist-info}/licenses/LICENSE +0 -0
@@ -9,7 +9,7 @@ from __future__ import annotations
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, tensor, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.autoregressive_wrapper import (
@@ -34,6 +34,25 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+# a custom flip that can handle variable lengths across batch
+
+def flip(x, dim = 1, lens = None):
+    if not exists(lens):
+        return x.flip(dim)
+
+    batch, seq_len, device = *x.shape[:2], x.device
+    seq = arange(seq_len, device = device)
+
+    mask = einx.less('j, i -> i j', seq, lens)
+    masked_seq = einx.where('i j, j,', mask, seq, -1)
+
+    flip_indices = masked_seq.argsort(dim = -1, descending = True)
+
+    if x.ndim == 3:
+        flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1])
+
+    return x.gather(dim, flip_indices)
+
 # wrappers
 
 class BeliefStateWrapper(Module):
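For readers who do not use einx: below is a minimal pure-torch sketch of the same length-aware flip idea, assuming right-padded sequences. The helper name flip_padded and the sample tensors are illustrative only and are not part of x-transformers.

import torch

def flip_padded(x, lens, dim = 1):
    # reverse only the first lens[i] positions of each row, leaving the right padding in place
    seq_len = x.shape[dim]
    seq = torch.arange(seq_len, device = x.device)
    mask = seq[None, :] < lens[:, None]                        # (batch, seq_len) valid-token mask
    masked_seq = torch.where(mask, seq, torch.full_like(seq, -1))
    flip_indices = masked_seq.argsort(dim = -1, descending = True)
    if x.ndim == 3:
        flip_indices = flip_indices[..., None].expand(-1, -1, x.shape[-1])
    return x.gather(dim, flip_indices)

x = torch.tensor([[1, 2, 0, 0],
                  [3, 4, 5, 6]])
lens = torch.tensor([2, 4])

print(flip_padded(x, lens))    # tensor([[2, 1, 0, 0], [6, 5, 4, 3]])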
@@ -230,19 +249,26 @@ class BeliefStateWrapper(Module):
     def forward(
         self,
         seq,
+        lens: Tensor | None = None, # Int['b']
         return_loss_only = False,
         loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
+        # handle variable length sequences
+
+        if exists(lens):
+            mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
+            seq_for_labels = torch.where(mask, seq, -1)
+
         # forward autoregressive
 
         forward_embeds = self.forward_decoder(seq, return_embeddings = True)
 
         # backward autoregressive
 
-        backward_seq =
+        backward_seq = flip(seq, lens = lens)
 
         suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)
 
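A hedged usage sketch of the new lens argument, not taken from the package documentation: it assumes BeliefStateWrapper wraps a standard TransformerWrapper decoder via a forward_decoder keyword (matching the self.forward_decoder attribute visible in the diff) and that forward(..., return_loss_only = True) returns a scalar loss.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# the constructor keyword below is an assumption based on the self.forward_decoder attribute
model = BeliefStateWrapper(forward_decoder = forward_model)

seq  = torch.randint(0, 256, (2, 1024))    # right-padded token ids, shape (batch, seq_len)
lens = torch.tensor([512, 1024])           # true length of each sequence in the batch

loss = model(seq, lens = lens, return_loss_only = True)
loss.backward()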
@@ -252,7 +278,7 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )
 
-        backward_embeds =
+        backward_embeds = flip(backward_embeds, lens = lens)
 
         # trick to reduce memory on backwards pass
 
@@ -294,7 +320,7 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), (bi - 1)
 
-        forward_labels, backward_labels =
+        forward_labels, backward_labels = seq_for_labels[:, labels_fi], seq_for_labels[:, labels_bi]
 
         labels = cat((forward_labels, backward_labels), dim = -1)
 
@@ -312,7 +338,8 @@
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
             labels,
-            reduction = 'none' if self.needs_loss_weight else 'mean'
+            reduction = 'none' if self.needs_loss_weight else 'mean',
+            ignore_index = -1
         )
 
         # maybe predict terminal
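A small self-contained check, not from the diff, of what the new ignore_index argument does: label positions filled with -1, as seq_for_labels does for padding, contribute nothing to the cross entropy.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 7, 5)                  # (batch, num_classes, positions)
labels = torch.tensor([[1, 2, -1, -1, -1],
                       [3, 4,  0,  2, -1]])    # -1 marks padded positions

loss = F.cross_entropy(logits, labels, ignore_index = -1)

# equivalent to averaging only over the non-padded positions
mask = labels != -1
per_pos = F.cross_entropy(logits, labels.clamp(min = 0), reduction = 'none')
assert torch.allclose(loss, per_pos[mask].mean())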
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=nx7NdEZQ98Puz1RwAl7wThFJ_R8xLpUbwoqYjb6IF28,12508
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.28.dist-info/METADATA,sha256=9VQXzWtJjNhONmS9sSxM4DQrJZJok1TgkUN0q8eT-S0,87875
+x_transformers-2.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.28.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.28.dist-info/RECORD,,
File without changes
File without changes