x-transformers 2.1.16__py3-none-any.whl → 2.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +28 -23
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.18.dist-info}/METADATA +1 -1
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.18.dist-info}/RECORD +5 -5
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.18.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.16.dist-info → x_transformers-2.1.18.dist-info}/licenses/LICENSE +0 -0
@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )
 
 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
@@ -126,7 +126,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv =
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +136,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
 
+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape
 
         out = prompts
@@ -197,14 +199,19 @@ class BeliefStateWrapper(Module):
 
             # concat sample
 
-            out = torch.cat((out, sample), dim
+            out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]
 
-
+        out, = unpack(out, batch_ps, '* n')
+
+        return out
 
     def forward(
         self,
         seq,
-
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
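
Taken with the pack call added in generate above, this hunk completes an einops pack/unpack round trip so a prompt can be passed with or without a leading batch dimension. Below is a minimal sketch of that pattern in isolation; the function name, shapes and dummy sampling step are illustrative stand-ins, not the wrapper's actual code.

# sketch of the pack/unpack pattern used in the new generate code
import torch
from einops import pack, unpack

def demo(prompts):
    # pack guarantees a leading batch dim and remembers the original shape
    prompts, batch_ps = pack([prompts], '* d')
    batch, orig_seq_len = prompts.shape

    # stand-in for autoregressive sampling: append a few dummy tokens
    sampled = torch.zeros(batch, 3, dtype = prompts.dtype)
    out = torch.cat((prompts, sampled), dim = -1)

    # drop the prompt, then restore the caller's original batch shape
    out = out[:, orig_seq_len:]
    out, = unpack(out, batch_ps, '* n')
    return out

print(demo(torch.ones(10, dtype = torch.long)).shape)     # torch.Size([3])
print(demo(torch.ones(2, 10, dtype = torch.long)).shape)  # torch.Size([2, 3])
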
@@ -244,6 +251,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices
 
         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2
 
         fb_pairs = fb_pairs[valid_mask]
@@ -265,8 +273,9 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), bi
 
-        forward_labels, backward_labels = seq[:,
-
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)
 
         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
 
@@ -281,7 +290,7 @@ class BeliefStateWrapper(Module):
 
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
@@ -290,12 +299,12 @@ class BeliefStateWrapper(Module):
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)
 
-
-
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
 
             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-
+                terminal_labels
             )
 
             loss = (
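
For orientation on the index arithmetic in these hunks: fb_pairs holds (forward, backward) position pairs, pairs less than two positions apart are filtered out, and pairs exactly two apart become positive terminal labels. The toy sketch below reproduces just that arithmetic; building fb_pairs via cartesian_prod here is an assumption for illustration, not necessarily how the wrapper constructs them.

# toy reproduction of the pair filtering and terminal-label logic shown above
import torch
from einops import repeat

seq_len, batch = 5, 2

# hypothetical fb_pairs: all (forward index, backward index) combinations
fi = torch.arange(seq_len)
bi = torch.arange(seq_len)
fb_pairs = torch.cartesian_prod(fi, bi)  # (seq_len * seq_len, 2)

fi, bi = fb_pairs.unbind(dim = -1)

# keep only pairs with at least one token between the two positions
valid_mask = (bi - fi) >= 2
fb_pairs = fb_pairs[valid_mask]

# pairs whose distance is exactly 2 are the positive terminal targets
fi, bi = fb_pairs.unbind(dim = -1)
terminal_labels = ((bi - fi) == 2).float()
terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)

print(fb_pairs.shape, terminal_labels.shape)
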
@@ -303,6 +312,11 @@ class BeliefStateWrapper(Module):
                 is_end_loss * self.pred_terminal_loss_weight
             )
 
+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
@@ -312,18 +326,9 @@ class BeliefStateWrapper(Module):
 
         # backwards
 
-
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()
 
-
-
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)
 
         return loss
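
With the patched loss.backward removed, the wrapper now performs its own backward pass, scaled by the new loss_scale argument, unless return_loss_only short-circuits before that point. A hedged usage sketch follows; model (a BeliefStateWrapper instance), seq (a batch of token ids) and opt (an optimizer over the wrapper's parameters) are assumed stand-ins, not anything defined in this diff.

import torch

# compute the loss only, e.g. for validation or logging; the early return
# added above fires before the wrapper's backward section
with torch.no_grad():
    val_loss = model(seq, return_loss_only = True)

# training step: the wrapper now calls (loss * loss_scale).backward() itself,
# replacing the old patched loss.backward; a loss_scale below 1 can be used
# for gradient accumulation over micro-batches
loss = model(seq, loss_scale = 0.5)
opt.step()
opt.zero_grad()
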
@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=MkZqVpwEZJfZGTCB24n2DvGbnFSZefk9LZ-HA5LqGS0,9803
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.18.dist-info/METADATA,sha256=bW46vXghSJjACgIgOlf5UYzX9fZz6ryEAoevC0DnE-0,87571
+x_transformers-2.1.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.18.dist-info/RECORD,,

File without changes
File without changes