x-transformers 2.1.20__py3-none-any.whl → 2.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/belief_state_wrapper.py

@@ -34,15 +34,6 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
-def eval_decorator(fn):
-    def inner(self, *args, **kwargs):
-        was_training = self.training
-        self.eval()
-        out = fn(self, *args, **kwargs)
-        self.train(was_training)
-        return out
-    return inner
-
 # wrappers
 
 class BeliefStateWrapper(Module):
@@ -69,6 +60,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.num_tokens = num_tokens
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -132,6 +125,7 @@ class BeliefStateWrapper(Module):
         filter_kwargs = dict(
             min_p = 0.1
         ),
+        decode_backwards = False,
         **kwargs
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
@@ -140,6 +134,14 @@ class BeliefStateWrapper(Module):
 
         batch, orig_seq_len = prompts.shape
 
+        # allow for decoding backwards, to make sure it is working
+
+        main_decoder = self.forward_decoder
+
+        if decode_backwards:
+            prompts = prompts.flip(1)
+            main_decoder = self.backward_decoder
+
         out = prompts
 
         # kv caches
@@ -148,31 +150,39 @@ class BeliefStateWrapper(Module):
 
         # get the encoded suffix token once
 
-        if exists(suffix):
-            if suffix.ndim == 1:
-                suffix = repeat(suffix, 'n -> b n', b = batch)
+        if not decode_backwards:
+            if exists(suffix):
+                if suffix.ndim == 1:
+                    suffix = repeat(suffix, 'n -> b n', b = batch)
 
-            suffix = suffix.flip(1) # reverse autoregressive
+                suffix = suffix.flip(1) # reverse autoregressive
 
-        suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
+            suffix_sos_tokens = rearrange(self.suffix_token, 'd -> 1 1 d')
 
-        suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
+            suffix_sos_tokens = repeat(suffix_sos_tokens, '1 1 d -> b 1 d', b = batch)
 
-        suffix_embed = self.backward_decoder(
-            suffix,
-            prepend_embeds = suffix_sos_tokens,
-            return_embeddings = True
-        )
+            suffix_embed = self.backward_decoder(
+                suffix,
+                prepend_embeds = suffix_sos_tokens,
+                return_embeddings = True
+            )
+
+            # pick out the last embedding for fill in the middle
+
+            suffix_embed = suffix_embed[:, -1:]
+
+        else:
+            # just grab a random token for now for prefix
 
-        # pick out the last embedding for fill in the middle
+            prefix_embed = torch.randint(0, self.num_tokens, (batch, 1), device = device)
 
-        suffix_embed = suffix_embed[:, -1:]
+            prefix_embed = self.forward_decoder(prefix_embed, return_embeddings = True)
 
         # sampling up to seq_len
 
         for _ in range(seq_len):
 
-            embeds, new_cache = self.forward_decoder(
+            embeds, new_cache = main_decoder(
                 out,
                 return_intermediates = True,
                 return_embeddings = True,
@@ -181,12 +191,18 @@ class BeliefStateWrapper(Module):
             )
 
             last_embeds = embeds[:, -1:]
-            embeds = cat((last_embeds, suffix_embed), dim = -1)
+
+            if not decode_backwards:
+                embeds = cat((last_embeds, suffix_embed), dim = -1)
+            else:
+                embeds = cat((prefix_embed, last_embeds), dim = -1)
 
             if cache_kv and self.forward_decoder.can_cache_kv:
                 cache = new_cache
 
-            logits, _ = self.text_head(embeds).chunk(2, dim = -1)
+            forward_logits, backward_logits = self.text_head(embeds).chunk(2, dim = -1)
+
+            logits = forward_logits if not decode_backwards else backward_logits
 
             logits = logits[:, -1]
 
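For context, the functional change in this release adds a decode_backwards flag to the wrapper's sampling method, routing generation through the backward decoder (on flipped prompts, conditioned on a random prefix token) instead of the forward decoder conditioned on a suffix. Below is a minimal usage sketch, not part of the diff: TransformerWrapper and Decoder are the library's standard building blocks, but the BeliefStateWrapper constructor keywords and the sampling method name (generate_with_suffix_cond) are assumptions inferred from the changed code, not confirmed API.

# Hypothetical usage sketch for the new decode_backwards flag (not from the diff).
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# forward and backward decoders over the same vocabulary (tiny dims for illustration)
forward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

backward_decoder = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

# constructor keyword names assumed from the attribute names used in the diff
model = BeliefStateWrapper(
    forward_decoder = forward_decoder,
    backward_decoder = backward_decoder
)

prompts = torch.randint(0, 256, (1, 8))
suffix = torch.randint(0, 256, (1, 4))

# forward decoding (existing behaviour), optionally conditioned on a suffix
forward_out = model.generate_with_suffix_cond(prompts, seq_len = 16, suffix = suffix)

# new in 2.1.21: sanity-check path that decodes in reverse with the backward decoder
backward_out = model.generate_with_suffix_cond(prompts, seq_len = 16, decode_backwards = True)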
x_transformers-2.1.21.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.20
+Version: 2.1.21
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-2.1.21.dist-info/RECORD

@@ -1,7 +1,7 @@
 x_transformers/__init__.py,sha256=NDoiBivau559WQ0FvXG4ssU3Il9aoHmTIUFN_1juz0s,911
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
-x_transformers/belief_state_wrapper.py,sha256=22jTxhNIKJuQFU8iRanOMpDdyqT_GiCZ2MAprxz6CGo,9841
+x_transformers/belief_state_wrapper.py,sha256=vQUg5djN8TJVhofhPhtMpMbbFz6d1lGsQOukLZhsa3I,10476
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -10,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=fqgtIs6__JpLWMnJa8AY5OW3AJ2GR1B5p-9TsWdiOIU,110425
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.20.dist-info/METADATA,sha256=YU5P-lgqBdEofFNiMZH1YIbgH8FddCS-l4K-n1o2h7o,87571
-x_transformers-2.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.1.20.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.1.20.dist-info/RECORD,,
+x_transformers-2.1.21.dist-info/METADATA,sha256=kTobyUmA8d0yo8NuwHSrpRchyfBMGZZAa2b0Ly3hZec,87571
+x_transformers-2.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.21.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.21.dist-info/RECORD,,