x-transformers 2.1.15__py3-none-any.whl → 2.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +40 -8
- {x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/METADATA +1 -1
- {x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/RECORD +5 -5
- {x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -24,6 +24,7 @@ from x_transformers.x_transformers import (
 
 import einx
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 
 # helper functions
 
@@ -55,7 +56,9 @@ class BeliefStateWrapper(Module):
         backward_decoder: TransformerWrapper | None = None,
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
-        backward_ar_loss_weight: float = 1
+        backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
+        pred_terminal = False,
+        pred_terminal_loss_weight: float = 1.
     ):
         super().__init__()
         backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
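For orientation, here is a minimal, hypothetical usage sketch of the two new constructor arguments. The `TransformerWrapper`/`Decoder` configuration, sizes, and call pattern are illustrative assumptions, not part of this diff; only `pred_terminal`, `pred_terminal_loss_weight`, and the fact that the wrapper now returns its loss (see the final hunk) come from the changes shown here.

```python
# hypothetical usage of the new pred_terminal option (sizes and setup are made up)
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6)
)

model = BeliefStateWrapper(
    forward_decoder = forward_decoder,  # backward_decoder defaults to the same transformer
    pred_terminal = True,               # new: also train a terminal-state head
    pred_terminal_loss_weight = 1.      # new: weight on that auxiliary loss
)

seq = torch.randint(0, 256, (2, 1024))

loss = model(seq)  # as of 2.1.16 the loss is returned instead of None
loss.backward()    # note: backward is monkey-patched, see the final hunk
```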
@@ -82,6 +85,17 @@ class BeliefStateWrapper(Module):
 
         self.text_head = text_head
 
+        # predicting terminal state (when suffix and prefix predict the same token)
+
+        self.to_terminal_logit = nn.Sequential(
+            nn.Linear(dim * 2, dim),
+            nn.LeakyReLU(),
+            nn.Linear(dim, 1),
+            Rearrange('... 1 -> ...')
+        ) if pred_terminal else None
+
+        self.pred_terminal_loss_weight = pred_terminal_loss_weight
+
         # the two decoders, one which is causal forward, the other causal backwards
 
         self.forward_decoder = forward_decoder
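The head above maps the concatenated forward and backward embeddings (width `dim * 2`) down to a single logit per position, with `Rearrange('... 1 -> ...')` squeezing away the trailing singleton axis. A standalone sketch of just this module; `dim` and the input shape are arbitrary choices, not values from the diff:

```python
# standalone sketch of the terminal-state head; dim and shapes are arbitrary
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512

to_terminal_logit = nn.Sequential(
    nn.Linear(dim * 2, dim),    # compress concatenated fwd + bwd embeddings
    nn.LeakyReLU(),
    nn.Linear(dim, 1),          # one logit per position
    Rearrange('... 1 -> ...')   # (batch, n, 1) -> (batch, n)
)

fb_embeds = torch.randn(2, 100, dim * 2)
assert to_terminal_logit(fb_embeds).shape == (2, 100)
```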
@@ -265,22 +279,40 @@ class BeliefStateWrapper(Module):
 
         # cross entropy loss
 
-
+        loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
             rearrange(labels, 'b n fb -> b (fb n)'),
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
+        # maybe predict terminal
+
+        if exists(self.to_terminal_logit):
+            terminal_logits = self.to_terminal_logit(fb_embeds)
+
+            labels = ((bi - fi) == 2).float() # distance is exactly 2
+            labels = repeat(labels, 'n -> b n', b = batch)
+
+            is_end_loss = F.binary_cross_entropy_with_logits(
+                terminal_logits,
+                labels
+            )
+
+            loss = (
+                loss +
+                is_end_loss * self.pred_terminal_loss_weight
+            )
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
-
-
-
+            loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
+            loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+            loss = loss.mean()
 
         # backwards
 
-        orig_backward = getattr(
+        orig_backward = getattr(loss, 'backward')
 
         def patched_backward_fn(*args, **kwargs):
             orig_backward(*args, **kwargs)
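The terminal label for a forward/backward index pair is 1. when the backward position sits exactly two tokens ahead of the forward position. `fi`, `bi`, `fb_embeds`, and `batch` are defined earlier in the method, outside this hunk; the toy indices below are made up purely to illustrate the rule:

```python
# toy illustration of the terminal labeling rule ((bi - fi) == 2); indices are made up
import torch
from einops import repeat

batch = 2
fi = torch.tensor([0, 1, 2, 3])   # stand-ins for the forward positions
bi = torch.tensor([2, 4, 4, 5])   # stand-ins for the backward positions

labels = ((bi - fi) == 2).float()               # tensor([1., 0., 1., 1.])
labels = repeat(labels, 'n -> b n', b = batch)  # broadcast across the batch

assert labels.shape == (batch, 4)
```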
@@ -292,6 +324,6 @@ class BeliefStateWrapper(Module):
         if backward:
             patched_backward_fn()
         else:
-            setattr(
+            setattr(loss, 'backward', patched_backward_fn)
 
-        return
+        return loss
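The `getattr`/`setattr` pair above saves the loss tensor's bound `backward` method and swaps in a wrapper, so a caller's ordinary `loss.backward()` also triggers whatever extra work `patched_backward_fn` performs (its remaining body lies outside these hunks). A minimal demonstration of the same pattern on a plain tensor, with a print standing in for that extra work:

```python
# minimal demonstration of the loss.backward monkey-patching pattern
import torch

loss = (torch.randn(3, requires_grad = True) ** 2).sum()

orig_backward = getattr(loss, 'backward')  # save the bound method

def patched_backward_fn(*args, **kwargs):
    orig_backward(*args, **kwargs)
    print('extra work after backward')     # stand-in for the wrapper's extra step

setattr(loss, 'backward', patched_backward_fn)

loss.backward()  # runs autograd, then prints 'extra work after backward'
```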
{x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=VjUB73yFBWevN6xMc6_1s-Yc58pJv8SDAUUEXwpR-W0,9842
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.16.dist-info/METADATA,sha256=MXaag1fuq1BsAmyh9k8sSRiHpy-jAUQW2Hn1GC53MnQ,87571
+x_transformers-2.1.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.16.dist-info/RECORD,,
{x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/WHEEL

File without changes

{x_transformers-2.1.15.dist-info → x_transformers-2.1.16.dist-info}/licenses/LICENSE

File without changes