x-transformers 2.1.5__py3-none-any.whl → 2.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +1 -0
- x_transformers/{belief_state.py → belief_state_wrapper.py} +33 -7
- {x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/METADATA +1 -1
- {x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/RECORD +6 -6
- {x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/licenses/LICENSE +0 -0
x_transformers/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from x_transformers.x_transformers import (
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
+from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
 from x_transformers.continuous import (
     ContinuousTransformerWrapper,
x_transformers/{belief_state.py → belief_state_wrapper.py}
RENAMED
@@ -6,7 +6,7 @@
 import torch
 from torch.autograd import Function
 from torch.nn import Module, ModuleList
-from torch import nn, cat, stack, arange, cartesian_prod
+from torch import nn, cat, stack, tensor, arange, cartesian_prod
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import (
@@ -36,7 +36,8 @@ class BeliefStateWrapper(Module):
         self,
         forward_decoder: TransformerWrapper,
         backward_decoder: TransformerWrapper,
-        train_frac_forward_backward_pairs: float = 1
+        train_frac_forward_backward_pairs: float = 1.,
+        backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
     ):
         super().__init__()
         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
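For orientation, a minimal usage sketch of the new constructor argument. The decoder settings below are purely illustrative assumptions and are not taken from this diff; only forward_decoder, backward_decoder and backward_ar_loss_weight come from the signature above.

from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

# hypothetical small decoders, just to satisfy the constructor
forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4)
)

backward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4)
)

# new in 2.1.7: the backward decoder's autoregressive loss can be weighted
# differently, e.g. when forward and backward decoders share a backbone
wrapper = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    backward_decoder = backward_decoder,
    backward_ar_loss_weight = 0.5
)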
@@ -70,9 +71,17 @@ class BeliefStateWrapper(Module):
         self.train_frac_fb_pairs = train_frac_forward_backward_pairs
         self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
 
+        # loss weighting
+
+        self.backward_ar_loss_weight = backward_ar_loss_weight
+        self.needs_loss_weight = backward_ar_loss_weight != 1.
+
+        self.register_buffer('loss_weights', tensor([1., self.backward_ar_loss_weight]))
+
     def forward(
         self,
-        seq
+        seq,
+        backward = True
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
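The loss_weights buffer registered above pairs a fixed weight of 1. for the forward direction with backward_ar_loss_weight for the backward direction; the next hunk applies it via einx.multiply('b fb n, fb', ...). A standalone sketch of what that broadcast does, with made-up shapes:

import torch

# made-up shapes: batch = 2, fb = 2 (forward / backward), n = 5 positions
per_token_loss = torch.rand(2, 2, 5)    # 'b fb n', unreduced cross entropy
loss_weights = torch.tensor([1., 0.5])  # [1., backward_ar_loss_weight]

# einx.multiply('b fb n, fb', per_token_loss, loss_weights) broadcasts the two
# weights over the fb axis; in plain torch that is
weighted = per_token_loss * loss_weights.view(1, 2, 1)
loss = weighted.mean()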
@@ -149,14 +158,31 @@ class BeliefStateWrapper(Module):
 
         fb_loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)')
+            rearrange(labels, 'b n fb -> b (fb n)'),
+            reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
+        # maybe loss weighting
+
+        if self.needs_loss_weight:
+            fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
+            fb_loss = einx.multiply('b fb n, fb', fb_loss, self.loss_weights)
+            fb_loss = fb_loss.mean()
+
         # backwards
 
-        fb_loss
+        orig_backward = getattr(fb_loss, 'backward')
+
+        def patched_backward_fn(*args, **kwargs):
+            orig_backward(*args, **kwargs)
+            orig_forward_embeds.backward(forward_embeds.grad)
+            orig_backward_embeds.backward(backward_embeds.grad)
+
+        # can allow the researcher to call .backward from the outside
 
-
-
+        if backward:
+            patched_backward_fn()
+        else:
+            setattr(fb_loss, 'backward', patched_backward_fn)
 
         return fb_loss
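With the new backward flag, the gradient step (including re-injecting the cached embedding gradients) either runs inside forward or is deferred by patching .backward on the returned loss. A hedged sketch of both call patterns, reusing the same illustrative decoder settings as in the earlier sketch:

import torch
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

# same illustrative setup as above
fwd = TransformerWrapper(num_tokens = 256, max_seq_len = 1024, attn_layers = Decoder(dim = 512, depth = 4))
bwd = TransformerWrapper(num_tokens = 256, max_seq_len = 1024, attn_layers = Decoder(dim = 512, depth = 4))
wrapper = BeliefStateWrapper(forward_decoder = fwd, backward_decoder = bwd)

seq = torch.randint(0, 256, (2, 1024))  # hypothetical token batch

# default (backward = True): the backward pass runs inside forward
loss = wrapper(seq)

# defer it instead: .backward on the returned loss is patched so that calling it
# from the training loop also propagates the cached embedding gradients
loss = wrapper(seq, backward = False)
loss.backward()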
{x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/RECORD
RENAMED
@@ -1,7 +1,7 @@
-x_transformers/__init__.py,sha256=
+x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/
+x_transformers/belief_state_wrapper.py,sha256=WoagJe_cHguWAWzhEAmeDD_TzUsQcMQkgP2Mt88wrAg,5827
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.7.dist-info/METADATA,sha256=D7enhdRYdhTddt4kGSS-rSsvtHgtFaLYQH-cVuYxUVM,87570
+x_transformers-2.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.7.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.7.dist-info/RECORD,,
{x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/WHEEL
File without changes
{x_transformers-2.1.5.dist-info → x_transformers-2.1.7.dist-info}/licenses/LICENSE
File without changes