x-transformers 2.1.20__py3-none-any.whl → 2.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

x_transformers/belief_state_wrapper.py

@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers
 
 class BeliefStateWrapper(Module):
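For context on the hunk above: the removed eval_decorator is the usual helper that puts a module into eval mode for the duration of a sampling call and then restores whatever training state it had before. A minimal sketch of how such a decorator is typically applied is below; the Sampler module is hypothetical and only for illustration.

```python
import torch
from torch import nn

def eval_decorator(fn):
    # run the wrapped method in eval mode, then restore the previous training state
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

class Sampler(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.dropout = nn.Dropout(0.5)

    @eval_decorator
    @torch.no_grad()
    def generate(self, x):
        # dropout is inactive here because the decorator switched to eval mode
        return self.proj(self.dropout(x))

sampler = Sampler()
sampler.train()
out = sampler.generate(torch.randn(2, 4))
assert sampler.training  # original training mode is restored after sampling
```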
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.num_tokens = num_tokens
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@ class BeliefStateWrapper(Module):
 
         batch, orig_seq_len = prompts.shape
 
+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts
 
         # kv caches
@@ -148,32 +150,41 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
-
-            suffix = suffix.flip(1) # reverse autoregressive
-
         suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
         suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
 
-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)
 
-        # pick out the last embedding for fill in the middle
+                suffix = suffix.flip(1) # reverse autoregressive
 
-        suffix_embed = suffix_embed[:, -1:]
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix
+
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
+
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)
 
         # sampling up to seq_len
 
         for _ in range(seq_len):
 
-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
+                prepend_embeds = suffix_sos_tokens if decode_backwards else None,
                 return_intermediates = True,
                 return_embeddings = True,
                 cache = cache,
@@ -181,12 +192,18 @@ class BeliefStateWrapper(Module):
             )
 
             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)
 
             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache
 
-            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits
 
             logits = logits[:, -1]
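Taken together, these hunks add a decode_backwards flag to the sampling path: the prompt is flipped, the backward decoder becomes the main decoder (with the suffix start token prepended), a random prefix token embedding stands in for the suffix conditioning, and the backward half of the text head's logits is used. A rough usage sketch follows; the constructor arguments mirror the repository's README example for BeliefStateWrapper and the generate_with_suffix_cond method name is taken from the wrapper's sampling method, but treat both as assumptions rather than verified 2.1.22 API.

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# small forward and backward decoders; sizes are arbitrary for this sketch
forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

backward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

# constructor argument names follow the README example; treat them as assumptions
model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model
)

prompts = torch.randint(0, 256, (1, 16))

# decode_backwards = True flips the prompt and samples with the backward decoder,
# which the diff describes as a check that backward decoding works
sampled = model.generate_with_suffix_cond(
    prompts,
    seq_len = 32,
    decode_backwards = True
)
```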
 
x_transformers-2.1.20.dist-info/METADATA → x_transformers-2.1.22.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.20
+Version: 2.1.22
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2444,4 +2444,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Charpentier2024GPTOB,
+    title   = {GPT or BERT: why not both?},
+    author  = {Lucas Georges Gabriel Charpentier and David Samuel},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.24159},
+    url     = {https://api.semanticscholar.org/CorpusID:273707069}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

x_transformers-2.1.20.dist-info/RECORD → x_transformers-2.1.22.dist-info/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=22jTxhNIKJuQFU8iRanOMpDdyqT_GiCZ2MAprxz6CGo,9841
+x_transformers/belief_state_wrapper.py,sha256=crj0yaTNmszDYlueGu_plGKpxVg0GKH8Z-B66SQluNs,10550
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.20.dist-info/METADATA,sha256=YU5P-lgqBdEofFNiMZH1YIbgH8FddCS-l4K-n1o2h7o,87571
-x_transformers-2.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.20.dist-info/RECORD,,
+x_transformers-2.1.22.dist-info/METADATA,sha256=FBxTg2dObipuXg2cC_ykYpzLF1AHEebrMyRjoAc0Xk4,87875
+x_transformers-2.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.22.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.22.dist-info/RECORD,,