x-transformers 2.1.26__py3-none-any.whl → 2.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +28 -3
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/METADATA +1 -1
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/RECORD +5 -5
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -9,7 +9,7 @@ from __future__ import annotations
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, tensor, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, Tensor, arange, cartesian_prod
 import torch.nn.functional as F

 from x_transformers.autoregressive_wrapper import (
@@ -34,6 +34,25 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+# a custom flip that can handle variable lengths across batch
+
+def flip(x, dim = 1, lens = None):
+    if not exists(lens):
+        return x.flip(dim)
+
+    batch, seq_len, device = *x.shape[:2], x.device
+    seq = arange(seq_len, device = device)
+
+    mask = einx.less('j, i -> i j', seq, lens)
+    masked_seq = einx.where('i j, j,', mask, seq, -1)
+
+    flip_indices = masked_seq.argsort(dim = -1, descending = True)
+
+    if x.ndim == 3:
+        flip_indices = repeat(flip_indices, '... -> ... d', d = x.shape[-1])
+
+    return x.gather(dim, flip_indices)
+
 # wrappers

 class BeliefStateWrapper(Module):
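The flip added above reverses each sequence only within its valid length: positions at or beyond lens[i] are replaced with -1, so a descending argsort sends them to the tail while the valid prefix comes out reversed. A minimal plain-PyTorch sketch of the same trick (the package itself uses einx and einops.repeat; the helper name below is illustrative, not part of the package):

import torch

def flip_by_length(x, lens, dim = 1):
    # reverse only the first lens[i] positions of each row; padded slots stay at the tail
    seq_len = x.shape[1]
    seq = torch.arange(seq_len, device = x.device)

    mask = seq[None, :] < lens[:, None]                       # (batch, seq) validity mask
    masked_seq = torch.where(mask, seq, torch.full_like(seq, -1))

    # descending argsort reverses the valid positions and pushes the -1 slots last
    flip_indices = masked_seq.argsort(dim = -1, descending = True)

    if x.ndim == 3:                                           # broadcast indices over the feature dim
        flip_indices = flip_indices.unsqueeze(-1).expand(-1, -1, x.shape[-1])

    return x.gather(dim, flip_indices)

x = torch.tensor([[1, 2, 3, 0, 0],
                  [4, 5, 6, 7, 8]])
lens = torch.tensor([3, 5])

print(flip_by_length(x, lens))
# tensor([[3, 2, 1, 0, 0],
#         [8, 7, 6, 5, 4]])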
@@ -230,19 +249,25 @@ class BeliefStateWrapper(Module):
     def forward(
         self,
         seq,
+        lens: Tensor | None = None, # Int['b']
         return_loss_only = False,
         loss_scale = 1.,
         loss_weight_by_fb_indices: callable | None = None
     ):
         batch, seq_len, device = *seq.shape, seq.device

+        # handle variable length sequences
+
+        if exists(lens):
+            mask = einx.less('j, i -> i j', arange(seq_len, device = device), lens)
+
         # forward autoregressive

         forward_embeds = self.forward_decoder(seq, return_embeddings = True)

         # backward autoregressive

-        backward_seq =
+        backward_seq = flip(seq, lens = lens)

         suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)

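In the hunk above, einx.less('j, i -> i j', arange(seq_len, device = device), lens) builds the usual (batch, seq) key-padding-style boolean mask, True wherever a position falls inside that sample's length. A rough plain-PyTorch equivalent of that one expression, for readers unfamiliar with einx notation:

import torch

seq_len = 6
lens = torch.tensor([3, 6])

# einx.less('j, i -> i j', seq, lens) compares seq[j] < lens[i] and lays the result out as (i, j)
mask = torch.arange(seq_len)[None, :] < lens[:, None]

print(mask)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True,  True]])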
@@ -252,7 +277,7 @@ class BeliefStateWrapper(Module):
             return_embeddings = True
         )

-        backward_embeds =
+        backward_embeds = flip(backward_embeds, lens = lens)

         # trick to reduce memory on backwards pass

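Taken together, the change lets BeliefStateWrapper.forward accept per-sample lengths so the backward decoder sees each sequence flipped only within its valid region. A rough usage sketch, assuming a wrapper built as in the x-transformers README; the constructor arguments and the returned loss below are assumptions, not something this diff shows:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6)
)

wrapper = BeliefStateWrapper(model)      # assumed constructor call

seq  = torch.randint(0, 256, (2, 1024))
lens = torch.tensor([768, 1024])         # valid length per sample -- the new argument

loss = wrapper(seq, lens = lens)         # assumed to return the training loss
loss.backward()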
{x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=p9OZTX-xs8YAiPdDCpZLqCG8VGmF8OJU6Zriy3mTGfo,12399
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.27.dist-info/METADATA,sha256=mUUzC0KPtW-RFPxwHyY_xr72ZrfDjPjjhaCyH4erUTw,87875
+x_transformers-2.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.27.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.27.dist-info/RECORD,,
{x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/WHEEL: file without changes

{x_transformers-2.1.26.dist-info → x_transformers-2.1.27.dist-info}/licenses/LICENSE: file without changes