x-transformers 2.3.16__py3-none-any.whl → 2.3.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,8 +36,16 @@ def eval_decorator(fn):
36
36
 
37
37
  # for variable-length prefixes
38
38
 
39
+ def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
40
+ if pad == (0, 0):
41
+ return t
42
+
43
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
44
+ zeros = ((0, 0) * dims_from_right)
45
+ return F.pad(t, (*zeros, *pad), value = value)
46
+
39
47
  def align_right(t, lens, pad_id = 0):
40
- batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
48
+ batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
41
49
 
42
50
  assert lens.ndim == 1 and lens.shape[0] == batch
43
51
  assert lens.amax() <= seq_len
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
48
56
  batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
49
57
  prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
50
58
 
51
- t = F.pad(t, (max_pad_len, 0), value = pad_id)
59
+ t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
52
60
  offset = max_pad_len - pad_lens
53
61
 
54
- aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
62
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
55
63
  return aligned
56
64
 
57
65
  # nucleus
@@ -9,6 +9,8 @@ from torch.distributions import Normal
9
9
  import einx
10
10
  from einops import rearrange, reduce, pack, repeat, unpack
11
11
 
12
+ from x_transformers.autoregressive_wrapper import align_right
13
+
12
14
  from x_transformers.x_transformers import (
13
15
  Attention,
14
16
  AttentionLayers,
@@ -222,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
222
224
  net: ContinuousTransformerWrapper,
223
225
  loss_fn: Module | None = None,
224
226
  equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
225
- rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
226
227
  ):
227
228
  super().__init__()
228
229
  self.net = net
@@ -236,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
236
237
  self.loss_fn = loss_fn
237
238
  self.equal_loss_weight_batch = equal_loss_weight_batch
238
239
 
239
- # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
240
- # applied successfully in vjepa2 world model, with rollout steps of 2
241
- # rollout steps of 1 would be the same as single step autoregressive
242
-
243
- assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
244
- assert 1 <= rollout_steps
245
- self.rollout_steps = rollout_steps
246
-
247
240
  @torch.no_grad()
248
241
  def generate(
249
242
  self,
@@ -298,40 +291,17 @@ class ContinuousAutoregressiveWrapper(Module):
298
291
  self.net.train(was_training)
299
292
  return out
300
293
 
301
- def forward(
294
+ def forward_rollout(
302
295
  self,
303
296
  x,
297
+ rollout_steps = 2,
304
298
  **kwargs
305
299
  ):
306
- steps = self.rollout_steps
307
- one_step_autoregress = steps == 1
308
-
309
- # get the input
310
-
311
- inp = x[:, :-steps]
312
-
313
- # variables
314
-
315
- batch, seq_len, device = *inp.shape[:2], inp.device
316
-
317
- # get target
318
-
319
- seq_start_pos = None
320
-
321
- if one_step_autoregress:
322
- target = x[:, None, 1:]
323
- else:
324
-
325
- batch_arange = arange(batch, device = device)
326
- batch_arange = rearrange(batch_arange, 'b -> b 1 1')
327
- seq_arange = arange(seq_len, device = device)
328
- steps_arange = arange(steps, device = device) + 1
300
+ assert rollout_steps > 1
329
301
 
330
- target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
302
+ steps = rollout_steps
331
303
 
332
- target = x[batch_arange, target_indices] # rollout targets
333
-
334
- seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
304
+ device = x.device
335
305
 
336
306
  # assert inputs
337
307
 
@@ -348,53 +318,120 @@ class ContinuousAutoregressiveWrapper(Module):
348
318
  mask = einx.less('j, i -> i j', seq_arange, lens)
349
319
  kwargs['mask'] = mask
350
320
 
321
+ if not exists(lens):
322
+ batch, seq_len = x.shape[:2]
323
+ lens = torch.full((batch,), seq_len, device = device)
324
+
351
325
  # handle mask manually
352
326
 
353
327
  mask = kwargs.pop('mask', None)
354
328
 
355
- has_mask = exists(mask)
329
+ # pick a random range for each batch sample and align the sequence to the right for rollout loss
356
330
 
357
- # maybe rollout
331
+ valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
332
+ valid_sample = valid_tokens_for_rollout > 0
333
+
334
+ x = x[valid_sample] # remove invalid sequence (lens less than rollout steps)
335
+
336
+ if exists(mask):
337
+ mask = mask[valid_sample]
338
+
339
+ batch = x.shape[0]
340
+ seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
341
+
342
+ batch_arange = torch.arange(batch, device = device)
343
+ batch_arange = rearrange(batch_arange, 'b -> b 1')
358
344
 
359
- outputs = []
360
- masks = []
345
+ # crop out sequence to use
361
346
 
362
- for step_index in range(steps):
347
+ seq_end_pos = seq_start_pos + steps
348
+ max_end_pos = seq_end_pos.amax().item()
349
+ x = x[:, :max_end_pos]
363
350
 
364
- step_mask = None
365
- if has_mask:
366
- step_mask = mask[:, step_index:(step_index + seq_len)]
367
- masks.append(step_mask)
351
+ x = align_right(x, seq_end_pos)
368
352
 
369
- # forward
353
+ # get the input
354
+
355
+ inp, targets = x[:, :-steps], x[:, -steps:]
370
356
 
371
- out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
357
+ # maybe rollout
372
358
 
373
- outputs.append(out)
359
+ cache = None
360
+ preds = []
374
361
 
375
- inp = out
362
+ for _ in range(steps):
376
363
 
377
- if not one_step_autoregress:
378
- seq_start_pos.sub_(1)
364
+ out, cache = self.net(
365
+ inp,
366
+ seq_start_pos = seq_start_pos,
367
+ return_intermediates = True,
368
+ **kwargs
369
+ )
379
370
 
380
- # stack masks and predictions from rollouts
371
+ last_pred = out[..., -1:, :]
372
+
373
+ if self.probabilistic:
374
+ mean, var = last_pred
375
+ std = var.clamp(min = 1e-5).sqrt()
376
+ inp = torch.normal(mean, std)
377
+ else:
378
+ inp = last_pred
381
379
 
382
- masks = stack(masks, dim = 1) if exists(mask) else None
380
+ preds.append(last_pred)
383
381
 
384
- pred = stack(outputs, dim = 1)
382
+ # stack for predictions
383
+
384
+ preds = cat(preds, dim = 1)
385
385
 
386
386
  # loss
387
387
 
388
- loss = self.loss_fn(pred, target)
388
+ loss = self.loss_fn(preds, targets)
389
+
390
+ return loss.mean()
391
+
392
+ def forward(
393
+ self,
394
+ x,
395
+ rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
396
+ **kwargs
397
+ ):
398
+ if rollout_steps > 1:
399
+ return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
400
+
401
+ inp, target = x[:, :-1], x[:, 1:]
402
+
403
+ assert 'prepend_embeds' not in kwargs
404
+
405
+ # lens
406
+
407
+ lens = kwargs.pop('lens', None)
408
+
409
+ if exists(lens):
410
+ assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
411
+ seq_len, device = inp.shape[1], inp.device
412
+ seq_arange = torch.arange(seq_len, device = device)
413
+ mask = einx.less('j, i -> i j', seq_arange, lens)
414
+
415
+ kwargs['mask'] = mask
416
+
417
+ # mask
418
+
419
+ mask = kwargs.get('mask', None)
420
+
421
+ if exists(mask) and mask.shape[1] == x.shape[1]:
422
+ mask = mask[:, :-1]
423
+ kwargs['mask'] = mask
424
+
425
+ out = self.net(inp, **kwargs)
389
426
 
390
- # adjusting loss based on mask
427
+ loss = self.loss_fn(out, target)
391
428
 
392
- if has_mask:
429
+ if exists(mask):
393
430
  assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
394
431
 
395
432
  if self.equal_loss_weight_batch:
396
- loss = masked_mean(loss, masks)
433
+ loss = masked_mean(loss, mask)
397
434
  else:
398
- loss = loss[masks]
435
+ loss = loss[mask]
399
436
 
400
437
  return loss.mean()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.16
3
+ Version: 2.3.18
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,8 +1,8 @@
1
1
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
- x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
3
+ x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
5
+ x_transformers/continuous.py,sha256=uV2hLQOckeRsybqJy-0F8RhAyMPJlkVHmA7QqUJHG4g,12433
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
15
- x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.16.dist-info/RECORD,,
14
+ x_transformers-2.3.18.dist-info/METADATA,sha256=RKXNlO50fifu1Nas38iZRn6IJVDkv4Cen94XYVJlWg0,89897
15
+ x_transformers-2.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.18.dist-info/RECORD,,