x-transformers 2.4.9-py3-none-any.whl → 2.4.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/autoregressive_wrapper.py

@@ -8,7 +8,7 @@ from torch import nn, Tensor
  from torch.nn import Module
  import torch.nn.functional as F

- from einops import rearrange, pack, unpack
+ from einops import rearrange, repeat, pack, unpack

  def exists(val):
      return val is not None
@@ -34,6 +34,21 @@ def eval_decorator(fn):
          return out
      return inner

+ # gumbel topk
+
+ def log(t, eps = 1e-20):
+     return t.clamp(min = eps).log()
+
+ def gumbel_noise(t):
+     return -log(-log(torch.rand_like(t)))
+
+ # function for modifying all the cached key / values
+
+ def modify_cached_kv(cache, fn):
+     for inter in cache.attn_intermediates:
+         if inter.layer_type == 'a':
+             inter.cached_kv = [fn(t) for t in inter.cached_kv]
+
  # for variable lengthed prefixes

  def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
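The new log and gumbel_noise helpers implement the standard Gumbel-max / Gumbel-top-k trick: adding i.i.d. Gumbel(0, 1) noise to logits and taking the argmax draws one sample from the softmax distribution, and taking the top-k draws k distinct tokens (sampling without replacement), which is what the stochastic beam search added further down relies on. A minimal standalone sketch of the trick (the tensor shape and k = 4 are illustrative, not taken from the package):

import torch

def log(t, eps = 1e-20):
    # numerically safe log, mirroring the helper added in this release
    return t.clamp(min = eps).log()

def gumbel_noise(t):
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    return -log(-log(torch.rand_like(t)))

logits = torch.randn(2, 10)  # (batch, num_tokens), illustrative only

# Gumbel-max: argmax of perturbed logits is one sample from softmax(logits)
single_sample = (logits + gumbel_noise(logits)).argmax(dim = -1)

# Gumbel-top-k: top-k of perturbed logits yields k distinct tokens,
# i.e. sampling without replacement from the same distribution
beam_samples = (logits + gumbel_noise(logits)).topk(4, dim = -1).indices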
@@ -161,6 +176,167 @@ class AutoregressiveWrapper(Module):
          self.add_continuous_pred_head = net.add_continuous_pred_head
          self.next_embed_loss_weight = next_embed_loss_weight

+     @torch.no_grad()
+     @eval_decorator
+     def beam_search(
+         self,
+         prompts,
+         seq_len,
+         beams = 4,
+         eos_token = None,
+         temperature = 1.,
+         stochastic = False,
+         prompt_lens: Tensor | None = None,
+         filter_logits_fn: str | Callable = top_k,
+         restrict_to_max_seq_len = True,
+         filter_kwargs: dict = dict(),
+         cache_kv = True,
+         **kwargs
+     ):
+         assert not exists(eos_token), 'eos token not supported yet'
+
+         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
+
+         prompts, packed_shape = pack([prompts], '* n')
+
+         batch, orig_seq_len = prompts.shape
+
+         # handle filter logits fn given as string
+
+         if isinstance(filter_logits_fn, str):
+             assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
+
+             filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
+
+         # handle variable lengthed prompts (prefixes)
+
+         seq_start_pos = None
+         if exists(prompt_lens):
+             prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
+             seq_start_pos = orig_seq_len - prompt_lens
+
+         # output from which sampled tokens appended to
+
+         out = prompts
+
+         # kv caches
+
+         cache = None
+
+         should_cache = cache_kv and self.net.can_cache_kv
+
+         # scores for the beams
+
+         scores = torch.zeros((batch,), device = device)
+
+         batch_arange = torch.arange(batch, device = device)
+
+         # sampling up to seq_len
+
+         for i in range(seq_len):
+             is_first = i == 0
+
+             if restrict_to_max_seq_len:
+                 max_len_exceeded = out.shape[-1] > max_seq_len
+
+                 assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
+
+                 x = out[:, -max_seq_len:]
+
+                 if exists(cache):
+                     modify_cached_kv(cache, lambda t: t[..., -(max_seq_len - 1):, :])
+
+             logits, new_cache = self.net(
+                 x,
+                 return_intermediates = True,
+                 cache = cache,
+                 seq_start_pos = seq_start_pos,
+                 **kwargs
+             )
+
+             if should_cache:
+                 cache = new_cache
+
+             logits = logits[:, -1]
+
+             # to add to the scores
+
+             log_probs = logits.log_softmax(dim = -1)
+
+             # maybe filter by top_k, top_p (nucleus) for stochastic beam search
+
+             if stochastic and not greedy:
+                 logits = filter_logits_fn(logits, **filter_kwargs)
+                 logits = (logits / temperature) + gumbel_noise(logits)
+
+             # (gumbel) topk
+
+             samples = logits.topk(beams, dim = -1).indices
+
+             # get the scores for keeping track of beams
+
+             next_scores = log_probs.gather(-1, samples)
+
+             # expand beam times
+
+             scores = repeat(scores, 'b -> b beams', beams = beams)
+             scores = scores + next_scores
+
+             out = repeat(out, 'b ... -> (b beams) ...', beams = beams)
+             samples = rearrange(samples, 'b beams -> (b beams) 1')
+
+             if should_cache:
+                 modify_cached_kv(cache, lambda t: repeat(t, 'b ... -> (b beams) ...', beams = beams))
+
+             # concat sample
+
+             out = torch.cat((out, samples), dim=-1)
+
+             # sort by score and excise
+             # excise out the beams
+
+             scores = rearrange(scores, '(b prev_beams) next_beams -> b (prev_beams next_beams)', b = batch)
+             curr_num_beams = scores.shape[-1]
+
+             if curr_num_beams > beams:
+                 scores, sort_indices = scores.sort(dim = -1, descending = True)
+
+                 scores = scores[:, :beams]
+                 top_beams_indices = sort_indices[:, :beams]
+                 top_beams_indices = beams * batch_arange[:, None] + top_beams_indices
+
+                 flattened_beam_indices = rearrange(top_beams_indices, 'b beams -> (b beams)')
+
+                 out = out[flattened_beam_indices]
+
+                 modify_cached_kv(cache, lambda t: t[flattened_beam_indices])
+
+             scores = rearrange(scores, 'b beams -> (b beams)')
+
+             if not exists(eos_token):
+                 continue
+
+             is_eos_tokens = (out == eos_token)
+
+             if is_eos_tokens.any(dim = -1).all():
+                 break
+
+         if exists(eos_token):
+             # mask out everything after the eos tokens
+             shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+             mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+             out = out.masked_fill(mask, self.pad_value)
+
+         # select out the top beam
+
+         out = rearrange(out, '(b beams) seq -> b beams seq', b = batch)
+
+         out = out[:, 0, orig_seq_len:]
+
+         out, = unpack(out, packed_shape, '* n')
+
+         return out
+
      @torch.no_grad()
      @eval_decorator
      def generate(
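The hunk above adds a beam_search method to AutoregressiveWrapper alongside the existing generate. A minimal usage sketch, assuming the usual TransformerWrapper / Decoder / AutoregressiveWrapper setup from the project README (the model hyperparameters and prompt shapes below are illustrative only, not part of this release):

import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True   # rotary embeddings allow kv-cached decoding past max_seq_len
    )
)

model = AutoregressiveWrapper(model)

prompts = torch.randint(0, 20000, (2, 128))

# deterministic beam search over 4 beams; returns the top beam per batch element, shaped (2, 256)
sampled = model.beam_search(prompts, seq_len = 256, beams = 4)

# stochastic beam search: logits are filtered (top_k by default), temperature scaled,
# then perturbed with gumbel noise before the per-step topk
sampled = model.beam_search(
    prompts,
    seq_len = 256,
    beams = 4,
    stochastic = True,
    temperature = 1.
)

Note that, per the assert at the top of the new method, eos_token is not yet supported by beam_search in this version.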
x_transformers-{2.4.9 → 2.4.10}.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.4.9
+ Version: 2.4.10
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
x_transformers-{2.4.9 → 2.4.10}.dist-info/RECORD

@@ -1,6 +1,6 @@
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
  x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
- x_transformers/autoregressive_wrapper.py,sha256=3tDUiY5kNcxNUjRERoeuFV0mXztOvgGrckoACIfHvqI,12091
+ x_transformers/autoregressive_wrapper.py,sha256=oTGHT5k52gsF5MXvH1WxkJ1Zq3HEmQ4laNM96RMdeiY,17594
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
  x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -12,7 +12,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
  x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
- x_transformers-2.4.9.dist-info/METADATA,sha256=yRYvqg0EZr7jvv-sRfBC2iU2tCu_pUo37KtLxYu44hg,90223
- x_transformers-2.4.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.4.9.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.4.9.dist-info/RECORD,,
+ x_transformers-2.4.10.dist-info/METADATA,sha256=gbkmhEOCOTsxuZ7zwdya-GF2FkdqDjd9ZlVUcBuRHZ8,90224
+ x_transformers-2.4.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.4.10.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.4.10.dist-info/RECORD,,