x-transformers 2.3.16__py3-none-any.whl → 2.3.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,8 +36,16 @@ def eval_decorator(fn):
36
36
 
37
37
  # for variable-length prefixes
38
38
 
39
+ def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
40
+ if pad == (0, 0):
41
+ return t
42
+
43
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
44
+ zeros = ((0, 0) * dims_from_right)
45
+ return F.pad(t, (*zeros, *pad), value = value)
46
+
39
47
  def align_right(t, lens, pad_id = 0):
40
- batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
48
+ batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
41
49
 
42
50
  assert lens.ndim == 1 and lens.shape[0] == batch
43
51
  assert lens.amax() <= seq_len
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
48
56
  batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
49
57
  prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
50
58
 
51
- t = F.pad(t, (max_pad_len, 0), value = pad_id)
59
+ t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
52
60
  offset = max_pad_len - pad_lens
53
61
 
54
- aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
62
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
55
63
  return aligned
56
64
 
57
65
  # nucleus
@@ -9,6 +9,8 @@ from torch.distributions import Normal
9
9
  import einx
10
10
  from einops import rearrange, reduce, pack, repeat, unpack
11
11
 
12
+ from x_transformers.autoregressive_wrapper import align_right
13
+
12
14
  from x_transformers.x_transformers import (
13
15
  Attention,
14
16
  AttentionLayers,
@@ -222,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
222
224
  net: ContinuousTransformerWrapper,
223
225
  loss_fn: Module | None = None,
224
226
  equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
225
- rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
226
227
  ):
227
228
  super().__init__()
228
229
  self.net = net
@@ -236,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
236
237
  self.loss_fn = loss_fn
237
238
  self.equal_loss_weight_batch = equal_loss_weight_batch
238
239
 
239
- # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
240
- # applied successfully in vjepa2 world model, with rollout steps of 2
241
- # rollout steps of 1 would be the same as single step autoregressive
242
-
243
- assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
244
- assert 1 <= rollout_steps
245
- self.rollout_steps = rollout_steps
246
-
247
240
  @torch.no_grad()
248
241
  def generate(
249
242
  self,
@@ -298,40 +291,18 @@ class ContinuousAutoregressiveWrapper(Module):
298
291
  self.net.train(was_training)
299
292
  return out
300
293
 
301
- def forward(
294
+ def forward_rollout(
302
295
  self,
303
296
  x,
297
+ rollout_steps = 2,
304
298
  **kwargs
305
299
  ):
306
- steps = self.rollout_steps
307
- one_step_autoregress = steps == 1
308
-
309
- # get the input
310
-
311
- inp = x[:, :-steps]
300
+ assert rollout_steps > 1
301
+ assert not self.probabilistic, 'probabilistic not supported yet'
312
302
 
313
- # variables
314
-
315
- batch, seq_len, device = *inp.shape[:2], inp.device
316
-
317
- # get target
318
-
319
- seq_start_pos = None
320
-
321
- if one_step_autoregress:
322
- target = x[:, None, 1:]
323
- else:
324
-
325
- batch_arange = arange(batch, device = device)
326
- batch_arange = rearrange(batch_arange, 'b -> b 1 1')
327
- seq_arange = arange(seq_len, device = device)
328
- steps_arange = arange(steps, device = device) + 1
303
+ steps = rollout_steps
329
304
 
330
- target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
331
-
332
- target = x[batch_arange, target_indices] # rollout targets
333
-
334
- seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
305
+ device = x.device
335
306
 
336
307
  # assert inputs
337
308
 
@@ -348,53 +319,114 @@ class ContinuousAutoregressiveWrapper(Module):
348
319
  mask = einx.less('j, i -> i j', seq_arange, lens)
349
320
  kwargs['mask'] = mask
350
321
 
322
+ if not exists(lens):
323
+ batch, seq_len = x.shape[:2]
324
+ lens = torch.full((batch,), seq_len, device = device)
325
+
351
326
  # handle mask manually
352
327
 
353
328
  mask = kwargs.pop('mask', None)
354
329
 
355
- has_mask = exists(mask)
330
+ # pick a random range for each batch sample and align the sequence to the right for rollout loss
356
331
 
357
- # maybe rollout
332
+ valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
333
+ valid_sample = valid_tokens_for_rollout > 0
334
+
335
+ x = x[valid_sample] # remove invalid sequence (lens less than rollout steps)
336
+
337
+ if exists(mask):
338
+ mask = mask[valid_sample]
339
+
340
+ batch = x.shape[0]
341
+ seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
358
342
 
359
- outputs = []
360
- masks = []
343
+ batch_arange = torch.arange(batch, device = device)
344
+ batch_arange = rearrange(batch_arange, 'b -> b 1')
361
345
 
362
- for step_index in range(steps):
346
+ # crop out sequence to use
363
347
 
364
- step_mask = None
365
- if has_mask:
366
- step_mask = mask[:, step_index:(step_index + seq_len)]
367
- masks.append(step_mask)
348
+ seq_end_pos = seq_start_pos + steps
349
+ max_end_pos = seq_end_pos.amax().item()
350
+ x = x[:, :max_end_pos]
368
351
 
369
- # forward
352
+ x = align_right(x, seq_end_pos)
370
353
 
371
- out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
354
+ # get the input
355
+
356
+ inp, targets = x[:, :-steps], x[:, -steps:]
357
+
358
+ # maybe rollout
359
+
360
+ cache = None
361
+ preds = []
372
362
 
373
- outputs.append(out)
363
+ for _ in range(steps):
374
364
 
375
- inp = out
365
+ out, cache = self.net(
366
+ inp,
367
+ seq_start_pos = seq_start_pos,
368
+ return_intermediates = True,
369
+ **kwargs
370
+ )
376
371
 
377
- if not one_step_autoregress:
378
- seq_start_pos.sub_(1)
372
+ last_pred = out[:, -1:]
373
+ inp = last_pred
379
374
 
380
- # stack masks and predictions from rollouts
375
+ preds.append(last_pred)
381
376
 
382
- masks = stack(masks, dim = 1) if exists(mask) else None
377
+ # stack for predictions
383
378
 
384
- pred = stack(outputs, dim = 1)
379
+ preds = cat(preds, dim = 1)
385
380
 
386
381
  # loss
387
382
 
388
- loss = self.loss_fn(pred, target)
383
+ loss = self.loss_fn(preds, targets)
384
+
385
+ return loss.mean()
386
+
387
+ def forward(
388
+ self,
389
+ x,
390
+ rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
391
+ **kwargs
392
+ ):
393
+ if rollout_steps > 1:
394
+ return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
395
+
396
+ inp, target = x[:, :-1], x[:, 1:]
397
+
398
+ assert 'prepend_embeds' not in kwargs
399
+
400
+ # lens
401
+
402
+ lens = kwargs.pop('lens', None)
403
+
404
+ if exists(lens):
405
+ assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
406
+ seq_len, device = inp.shape[1], inp.device
407
+ seq_arange = torch.arange(seq_len, device = device)
408
+ mask = einx.less('j, i -> i j', seq_arange, lens)
409
+
410
+ kwargs['mask'] = mask
411
+
412
+ # mask
413
+
414
+ mask = kwargs.get('mask', None)
415
+
416
+ if exists(mask) and mask.shape[1] == x.shape[1]:
417
+ mask = mask[:, :-1]
418
+ kwargs['mask'] = mask
419
+
420
+ out = self.net(inp, **kwargs)
389
421
 
390
- # adjusting loss based on mask
422
+ loss = self.loss_fn(out, target)
391
423
 
392
- if has_mask:
424
+ if exists(mask):
393
425
  assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
394
426
 
395
427
  if self.equal_loss_weight_batch:
396
- loss = masked_mean(loss, masks)
428
+ loss = masked_mean(loss, mask)
397
429
  else:
398
- loss = loss[masks]
430
+ loss = loss[mask]
399
431
 
400
432
  return loss.mean()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.16
3
+ Version: 2.3.17
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,8 +1,8 @@
1
1
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
- x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
3
+ x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
5
+ x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
15
- x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.16.dist-info/RECORD,,
14
+ x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
15
+ x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.17.dist-info/RECORD,,