x-transformers 2.1.20-py3-none-any.whl → 2.1.21-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +41 -25
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/METADATA +1 -1
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/RECORD +5 -5
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers

 class BeliefStateWrapper(Module):
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens

+        self.num_tokens = num_tokens
+
         # the suffix token

         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@ class BeliefStateWrapper(Module):

         batch, orig_seq_len = prompts.shape

+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts

         # kv caches
@@ -148,31 +150,39 @@ class BeliefStateWrapper(Module):

         # get the encoded suffix token once

-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)

-            suffix = suffix.flip(1) # reverse autoregressive
+                suffix = suffix.flip(1) # reverse autoregressive

-        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+            suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')

-        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+            suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)

-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix

-        # pick out the last embedding for fill in the middle
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)

-        suffix_embed = suffix_embed[:, -1:]
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)

         # sampling up to seq_len

         for _ in range(seq_len):

-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
                 return_intermediates = True,
                 return_embeddings = True,
@@ -181,12 +191,18 @@ class BeliefStateWrapper(Module):
             )

             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)

             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache

-            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits

             logits = logits[:, -1]

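Taken together, these hunks add an opt-in decode_backwards flag to the wrapper's sampling path: prompts are flipped, the backward decoder drives decoding, a randomly chosen token stands in for the prefix conditioning, and the backward half of the text head supplies the logits. Below is a minimal sketch of how the flag might be exercised; the sampling method name generate_with_suffix_cond, the constructor keywords, and the model sizes are assumptions for illustration and are not confirmed by this diff.

# minimal usage sketch, assuming x-transformers' TransformerWrapper / Decoder API
# and a BeliefStateWrapper sampling method named generate_with_suffix_cond
# (names not shown in the diff above are hypothetical)

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# forward and backward decoders are assumed to share the same configuration
forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 64, depth = 2)
)

backward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 64, depth = 2)
)

model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    backward_decoder = backward_decoder
)

prompts = torch.randint(0, 256, (2, 16))

# new in 2.1.21: decode_backwards = True flips the prompts, routes decoding
# through the backward decoder, and reads logits from the backward head
sampled = model.generate_with_suffix_cond(prompts, 64, decode_backwards = True)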
{x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=vQUg5djN8TJVhofhPhtMpMbbFz6d1lGsQOukLZhsa3I,10476
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.21.dist-info/METADATA,sha256=kTobyUmA8d0yo8NuwHSrpRchyfBMGZZAa2b0Ly3hZec,87571
+x_transformers-2.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.21.dist-info/RECORD,,
{x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/WHEEL
File without changes

{x_transformers-2.1.20.dist-info → x_transformers-2.1.21.dist-info}/licenses/LICENSE
File without changes