x-transformers 2.3.15__py3-none-any.whl → 2.3.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,8 +36,16 @@ def eval_decorator(fn):
36
36
 
37
37
  # for variable-length prefixes
38
38
 
39
+ def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
40
+ if pad == (0, 0):
41
+ return t
42
+
43
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
44
+ zeros = ((0, 0) * dims_from_right)
45
+ return F.pad(t, (*zeros, *pad), value = value)
46
+
39
47
  def align_right(t, lens, pad_id = 0):
40
- batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
48
+ batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
41
49
 
42
50
  assert lens.ndim == 1 and lens.shape[0] == batch
43
51
  assert lens.amax() <= seq_len
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
48
56
  batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
49
57
  prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
50
58
 
51
- t = F.pad(t, (max_pad_len, 0), value = pad_id)
59
+ t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
52
60
  offset = max_pad_len - pad_lens
53
61
 
54
- aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
62
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
55
63
  return aligned
56
64
 
57
65
  # nucleus
@@ -9,6 +9,8 @@ from torch.distributions import Normal
9
9
  import einx
10
10
  from einops import rearrange, reduce, pack, repeat, unpack
11
11
 
12
+ from x_transformers.autoregressive_wrapper import align_right
13
+
12
14
  from x_transformers.x_transformers import (
13
15
  Attention,
14
16
  AttentionLayers,
@@ -130,6 +132,7 @@ class ContinuousTransformerWrapper(Module):
130
132
  sum_embeds = None,
131
133
  prepend_embeds = None,
132
134
  prepend_mask = None,
135
+ seq_start_pos = None,
133
136
  **kwargs
134
137
  ):
135
138
  batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
@@ -145,7 +148,7 @@ class ContinuousTransformerWrapper(Module):
145
148
  # project in + positional embedding
146
149
 
147
150
  x = self.project_in(x)
148
- x = x + self.pos_emb(x, pos = pos)
151
+ x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
149
152
 
150
153
  if exists(sum_embeds):
151
154
  x = x + sum_embeds
@@ -221,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
221
224
  net: ContinuousTransformerWrapper,
222
225
  loss_fn: Module | None = None,
223
226
  equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
224
- rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
225
227
  ):
226
228
  super().__init__()
227
229
  self.net = net
@@ -235,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
235
237
  self.loss_fn = loss_fn
236
238
  self.equal_loss_weight_batch = equal_loss_weight_batch
237
239
 
238
- # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
239
- # applied successfully in vjepa2 world model, with rollout steps of 2
240
- # rollout steps of 1 would be the same as single step autoregressive
241
-
242
- assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
243
- assert 1 <= rollout_steps
244
- self.rollout_steps = rollout_steps
245
-
246
240
  @torch.no_grad()
247
241
  def generate(
248
242
  self,
@@ -297,36 +291,18 @@ class ContinuousAutoregressiveWrapper(Module):
297
291
  self.net.train(was_training)
298
292
  return out
299
293
 
300
- def forward(
294
+ def forward_rollout(
301
295
  self,
302
296
  x,
297
+ rollout_steps = 2,
303
298
  **kwargs
304
299
  ):
305
- steps = self.rollout_steps
306
- one_step_autoregress = steps == 1
307
-
308
- # get the input
309
-
310
- inp = x[:, :-steps]
311
-
312
- # variables
313
-
314
- batch, seq_len, device = *inp.shape[:2], inp.device
315
-
316
- # get target
317
-
318
- if one_step_autoregress:
319
- target = x[:, None, 1:]
320
- else:
321
-
322
- batch_arange = arange(batch, device = device)
323
- batch_arange = rearrange(batch_arange, 'b -> b 1 1')
324
- seq_arange = arange(seq_len, device = device)
325
- steps_arange = arange(steps, device = device) + 1
300
+ assert rollout_steps > 1
301
+ assert not self.probabilistic, 'probabilistic not supported yet'
326
302
 
327
- target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
303
+ steps = rollout_steps
328
304
 
329
- target = x[batch_arange, target_indices] # rollout targets
305
+ device = x.device
330
306
 
331
307
  # assert inputs
332
308
 
@@ -343,50 +319,114 @@ class ContinuousAutoregressiveWrapper(Module):
343
319
  mask = einx.less('j, i -> i j', seq_arange, lens)
344
320
  kwargs['mask'] = mask
345
321
 
322
+ if not exists(lens):
323
+ batch, seq_len = x.shape[:2]
324
+ lens = torch.full((batch,), seq_len, device = device)
325
+
346
326
  # handle mask manually
347
327
 
348
328
  mask = kwargs.pop('mask', None)
349
329
 
350
- has_mask = exists(mask)
330
+ # pick a random range for each batch sample and align the sequence to the right for rollout loss
351
331
 
352
- # maybe rollout
332
+ valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
333
+ valid_sample = valid_tokens_for_rollout > 0
353
334
 
354
- outputs = []
355
- masks = []
335
+ x = x[valid_sample] # remove invalid sequences (lens less than rollout steps)
356
336
 
357
- for step_index in range(steps):
337
+ if exists(mask):
338
+ mask = mask[valid_sample]
358
339
 
359
- step_mask = None
360
- if has_mask:
361
- step_mask = mask[:, step_index:(step_index + seq_len)]
362
- masks.append(step_mask)
340
+ batch = x.shape[0]
341
+ seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
363
342
 
364
- # forward
343
+ batch_arange = torch.arange(batch, device = device)
344
+ batch_arange = rearrange(batch_arange, 'b -> b 1')
365
345
 
366
- out = self.net(inp, mask = step_mask, **kwargs)
346
+ # crop out sequence to use
367
347
 
368
- outputs.append(out)
348
+ seq_end_pos = seq_start_pos + steps
349
+ max_end_pos = seq_end_pos.amax().item()
350
+ x = x[:, :max_end_pos]
369
351
 
370
- inp = out
352
+ x = align_right(x, seq_end_pos)
353
+
354
+ # get the input
371
355
 
372
- # stack masks and predictions from rollouts
356
+ inp, targets = x[:, :-steps], x[:, -steps:]
373
357
 
374
- masks = stack(masks, dim = 1) if exists(mask) else None
358
+ # maybe rollout
359
+
360
+ cache = None
361
+ preds = []
375
362
 
376
- pred = stack(outputs, dim = 1)
363
+ for _ in range(steps):
364
+
365
+ out, cache = self.net(
366
+ inp,
367
+ seq_start_pos = seq_start_pos,
368
+ return_intermediates = True,
369
+ **kwargs
370
+ )
371
+
372
+ last_pred = out[:, -1:]
373
+ inp = last_pred
374
+
375
+ preds.append(last_pred)
376
+
377
+ # stack for predictions
378
+
379
+ preds = cat(preds, dim = 1)
377
380
 
378
381
  # loss
379
382
 
380
- loss = self.loss_fn(pred, target)
383
+ loss = self.loss_fn(preds, targets)
384
+
385
+ return loss.mean()
386
+
387
+ def forward(
388
+ self,
389
+ x,
390
+ rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
391
+ **kwargs
392
+ ):
393
+ if rollout_steps > 1:
394
+ return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
395
+
396
+ inp, target = x[:, :-1], x[:, 1:]
397
+
398
+ assert 'prepend_embeds' not in kwargs
399
+
400
+ # lens
401
+
402
+ lens = kwargs.pop('lens', None)
403
+
404
+ if exists(lens):
405
+ assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
406
+ seq_len, device = inp.shape[1], inp.device
407
+ seq_arange = torch.arange(seq_len, device = device)
408
+ mask = einx.less('j, i -> i j', seq_arange, lens)
409
+
410
+ kwargs['mask'] = mask
411
+
412
+ # mask
413
+
414
+ mask = kwargs.get('mask', None)
415
+
416
+ if exists(mask) and mask.shape[1] == x.shape[1]:
417
+ mask = mask[:, :-1]
418
+ kwargs['mask'] = mask
419
+
420
+ out = self.net(inp, **kwargs)
381
421
 
382
- # adjusting loss based on mask
422
+ loss = self.loss_fn(out, target)
383
423
 
384
- if has_mask:
424
+ if exists(mask):
385
425
  assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
386
426
 
387
427
  if self.equal_loss_weight_batch:
388
- loss = masked_mean(loss, masks)
428
+ loss = masked_mean(loss, mask)
389
429
  else:
390
- loss = loss[masks]
430
+ loss = loss[mask]
391
431
 
392
432
  return loss.mean()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.15
3
+ Version: 2.3.17
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,8 +1,8 @@
1
1
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
- x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
3
+ x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=5SZmi3Bd77aJAu50f4y1OwwruZd_3ZHptC8dtQmvvxM,11387
5
+ x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.15.dist-info/METADATA,sha256=Dv6r9tZhbF-_7q5yhgmo3pto-D6NVJnYPNnSBEBt73I,89897
15
- x_transformers-2.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.15.dist-info/RECORD,,
14
+ x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
15
+ x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.17.dist-info/RECORD,,