x-transformers 2.1.4__py3-none-any.whl → 2.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state.py +22 -2
- {x_transformers-2.1.4.dist-info → x_transformers-2.1.5.dist-info}/METADATA +1 -1
- {x_transformers-2.1.4.dist-info → x_transformers-2.1.5.dist-info}/RECORD +5 -5
- {x_transformers-2.1.4.dist-info → x_transformers-2.1.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.4.dist-info → x_transformers-2.1.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state.py
CHANGED
@@ -35,11 +35,12 @@ class BeliefStateWrapper(Module):
|
|
35
35
|
def __init__(
|
36
36
|
self,
|
37
37
|
forward_decoder: TransformerWrapper,
|
38
|
-
backward_decoder: TransformerWrapper
|
38
|
+
backward_decoder: TransformerWrapper,
|
39
|
+
train_frac_forward_backward_pairs: float = 1.
|
39
40
|
):
|
40
41
|
super().__init__()
|
41
42
|
assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backwards model must have the same embedding dimension'
|
42
|
-
assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same
|
43
|
+
assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backwards model must have the same number of tokens'
|
43
44
|
|
44
45
|
dim = forward_decoder.emb_dim
|
45
46
|
num_tokens = forward_decoder.num_tokens
|
@@ -47,6 +48,7 @@ class BeliefStateWrapper(Module):
|
|
47
48
|
# the suffix token
|
48
49
|
|
49
50
|
self.suffix_token = nn.Parameter(torch.zeros(dim))
|
51
|
+
nn.init.normal_(self.suffix_token, std = 0.02)
|
50
52
|
|
51
53
|
# the text prediction head, which predicts for the combinations of prefix and suffix the next and previous token for forwards and backward sequences
|
52
54
|
|
@@ -61,6 +63,13 @@ class BeliefStateWrapper(Module):
|
|
61
63
|
self.forward_decoder = forward_decoder
|
62
64
|
self.backward_decoder = backward_decoder
|
63
65
|
|
66
|
+
# what fraction of forward backward pairs to train on
|
67
|
+
# for further memory efficiency
|
68
|
+
|
69
|
+
assert 0 < train_frac_forward_backward_pairs <= 1.
|
70
|
+
self.train_frac_fb_pairs = train_frac_forward_backward_pairs
|
71
|
+
self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
|
72
|
+
|
64
73
|
def forward(
|
65
74
|
self,
|
66
75
|
seq
|
@@ -107,6 +116,17 @@ class BeliefStateWrapper(Module):
|
|
107
116
|
|
108
117
|
fb_pairs = fb_pairs[valid_mask]
|
109
118
|
|
119
|
+
# maybe subsample fb pairs
|
120
|
+
|
121
|
+
if self.needs_subsample_fb_pairs:
|
122
|
+
num_pairs = fb_pairs.shape[0]
|
123
|
+
|
124
|
+
num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
|
125
|
+
|
126
|
+
rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
|
127
|
+
|
128
|
+
fb_pairs = fb_pairs[rand_subsampled_indices]
|
129
|
+
|
110
130
|
# get labels for both
|
111
131
|
|
112
132
|
fi, bi = fb_pairs.unbind(dim = -1)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
|
2
2
|
x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
|
-
x_transformers/belief_state.py,sha256=
|
4
|
+
x_transformers/belief_state.py,sha256=GfYDeDqmhldozECgFsJ9zhd6O5NMvdYA5OwueVs8SB4,4742
|
5
5
|
x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
10
10
|
x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
|
11
11
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
12
12
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
13
|
-
x_transformers-2.1.
|
14
|
-
x_transformers-2.1.
|
15
|
-
x_transformers-2.1.
|
16
|
-
x_transformers-2.1.
|
13
|
+
x_transformers-2.1.5.dist-info/METADATA,sha256=-nQpm1eBGBXkLEXGiQK06NaIWa13CEnBujBNfTzvnJ8,87570
|
14
|
+
x_transformers-2.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
x_transformers-2.1.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
16
|
+
x_transformers-2.1.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|