x-transformers 2.1.20__py3-none-any.whl → 2.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/belief_state_wrapper.py +42 -25
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/METADATA +12 -1
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/RECORD +5 -5
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/licenses/LICENSE +0 -0
x_transformers/belief_state_wrapper.py

@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers

 class BeliefStateWrapper(Module):
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens

+        self.num_tokens = num_tokens
+
         # the suffix token

         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@ class BeliefStateWrapper(Module):

         batch, orig_seq_len = prompts.shape

+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts

         # kv caches
@@ -148,32 +150,41 @@ class BeliefStateWrapper(Module):

         # get the encoded suffix token once

-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
-
-            suffix = suffix.flip(1) # reverse autoregressive
-
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')

         suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)

-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)

-        # pick out the last embedding for fill in the middle
+                suffix = suffix.flip(1) # reverse autoregressive

-        suffix_embed = suffix_embed[:, -1:]
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix
+
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
+
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)

         # sampling up to seq_len

         for _ in range(seq_len):

-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
+                prepend_embeds = suffix_sos_tokens if decode_backwards else None,
                 return_intermediates = True,
                 return_embeddings = True,
                 cache = cache,
@@ -181,12 +192,18 @@ class BeliefStateWrapper(Module):
             )

             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)

             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache

-
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits

             logits = logits[:, -1]

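Taken together, these hunks add a `decode_backwards` flag to `BeliefStateWrapper.generate`: the prompt is flipped, the backward decoder becomes the main decoder, a random-token prefix embedding (hence the newly stored `self.num_tokens`) stands in for the suffix conditioning, and the backward half of the `text_head` logits is sampled. A minimal usage sketch, assuming the usual `TransformerWrapper`/`Decoder` construction from the library README, that `backward_decoder` may be omitted from the constructor, and illustrative hyperparameters (vocab 256, dim 512, depth 4) that are not taken from this diff:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# illustrative model sizes, not taken from the diff
forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

# assumption: backward_decoder may be left out; otherwise pass a second TransformerWrapper
wrapper = BeliefStateWrapper(forward_decoder = forward_model)

prompts = torch.randint(0, 256, (2, 16))   # (batch, prefix length)
suffix  = torch.randint(0, 256, (2, 8))    # goal tokens, reversed internally

# forward, suffix-conditioned decoding (existing behavior)
sampled = wrapper.generate(prompts, seq_len = 64, suffix = suffix)

# new in this release: decode with the backward decoder as a sanity check;
# prompts are flipped internally and the backward logits are sampled
sampled_backwards = wrapper.generate(prompts, seq_len = 64, decode_backwards = True)
```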
{x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.20
+Version: 2.1.22
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2444,4 +2444,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@article{Charpentier2024GPTOB,
+    title   = {GPT or BERT: why not both?},
+    author  = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.24159},
+    url     = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=
+x_transformers/belief_state_wrapper.py,sha256=crj0yaTNmszDYlueGu_plGKpxVg0GKH8Z-B66SQluNs,10550
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.22.dist-info/METADATA,sha256=FBxTg2dObipuXg2cC_ykYpzLF1AHEebrMyRjoAc0Xk4,87875
+x_transformers-2.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.22.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.22.dist-info/RECORD,,
{x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/WHEEL: File without changes

{x_transformers-2.1.20.dist-info → x_transformers-2.1.22.dist-info}/licenses/LICENSE: File without changes