x-transformers 2.3.14__py3-none-any.whl → 2.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
  from __future__ import annotations

  import torch
- from torch import nn, cat, stack
+ from torch import nn, cat, stack, arange
  from torch.nn import Module
  import torch.nn.functional as F
  from torch.distributions import Normal
@@ -64,7 +64,7 @@ class ContinuousTransformerWrapper(Module):
  use_abs_pos_emb = True,
  scaled_sinu_pos_emb = False,
  average_pool_embed = False,
- probabilistic = False
+ probabilistic = False,
  ):
  super().__init__()
  dim = attn_layers.dim
@@ -138,7 +138,7 @@ class ContinuousTransformerWrapper(Module):

  if exists(lens):
  assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
- seq_arange = torch.arange(seq, device = device)
+ seq_arange = arange(seq, device = device)

  mask = einx.less('j, i -> i j', seq_arange, lens)

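The hunk above also shows how a `lens` tensor of per-sequence lengths becomes a boolean padding mask via `einx.less('j, i -> i j', seq_arange, lens)`. A minimal plain-torch sketch of the same computation (the batch size and lengths below are illustrative assumptions, not values from the package):

```python
# equivalent of einx.less('j, i -> i j', seq_arange, lens) via torch broadcasting;
# the batch size and lengths here are illustrative assumptions
import torch

lens = torch.tensor([3, 5])        # per-sequence lengths, shape (batch,)
seq_arange = torch.arange(6)       # positions 0..seq-1, shape (seq,)

# mask[i, j] is True while position j lies inside sequence i
mask = seq_arange[None, :] < lens[:, None]   # shape (batch, seq)

print(mask)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])
```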
@@ -220,7 +220,8 @@ class ContinuousAutoregressiveWrapper(Module):
  self,
  net: ContinuousTransformerWrapper,
  loss_fn: Module | None = None,
- equal_loss_weight_batch = False # if set to True and a mask is passed in with variable length sequences, each sequence will be weighted the same (as opposed to each token)
+ equal_loss_weight_batch = False, # if set to True and a mask is passed in with variable length sequences, each sequence will be weighted the same (as opposed to each token)
+ rollout_steps = 1 # the V-JEPA 2 world model paper successfully used 2 rollout steps https://ai.meta.com/vjepa/
  ):
  super().__init__()
  self.net = net
@@ -234,6 +235,14 @@ class ContinuousAutoregressiveWrapper(Module):
  self.loss_fn = loss_fn
  self.equal_loss_weight_batch = equal_loss_weight_batch

+ # num rollout steps - if greater than one, recurrently feed the output back in and enforce the loss (rollout steps - 1) additional steps ahead
+ # applied successfully in the V-JEPA 2 world model, with rollout steps of 2
+ # rollout steps of 1 is the same as single step autoregressive
+
+ assert not (rollout_steps > 1 and probabilistic), 'rollout steps greater than 1 is only supported for non-probabilistic'
+ assert 1 <= rollout_steps
+ self.rollout_steps = rollout_steps
+
  @torch.no_grad()
  def generate(
  self,
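The new `rollout_steps` argument makes the wrapper feed its own predictions back in during training, enforcing the loss more than one step ahead. A hedged usage sketch (every model dimension and hyperparameter below is an illustrative assumption):

```python
# usage sketch for the new rollout_steps argument; all dimensions and
# hyperparameters below are illustrative assumptions
import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 256, depth = 4, heads = 8)
)

# rollout_steps = 2 mirrors the V-JEPA 2 setting referenced in the diff's comments
wrapper = ContinuousAutoregressiveWrapper(model, rollout_steps = 2)

x = torch.randn(1, 256, 32)
loss = wrapper(x)    # scalar loss averaged over both rollout steps
loss.backward()
```

Note that feeding predictions back in requires `dim_out` to equal `dim_in`, and the new asserts restrict rollouts to the non-probabilistic setting.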
@@ -247,12 +256,13 @@ class ContinuousAutoregressiveWrapper(Module):
  device = start_tokens.device

  was_training = self.net.training
- num_dims = len(start_tokens.shape)
+ num_dims = start_tokens.ndim

  assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
+ no_batch = num_dims == 2

- if num_dims == 2:
- start_tokens = start_tokens[None, :]
+ if no_batch:
+ start_tokens = rearrange(start_tokens, 'n d -> 1 n d')

  b, t, _, device = *start_tokens.shape, start_tokens.device

@@ -281,8 +291,8 @@ class ContinuousAutoregressiveWrapper(Module):

  out = out[:, t:]

- if num_dims == 2:
- out = out.squeeze(0)
+ if no_batch:
+ out = rearrange(out, '1 n d -> n d')

  self.net.train(was_training)
  return out
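The `generate` change is cosmetic: an unbatched `(seq, dim)` prompt is now batched and un-batched with explicit `rearrange` calls instead of indexing and `squeeze`. Continuing the sketch above (prompt and generation lengths are illustrative assumptions):

```python
# generation with an unbatched prompt, continuing the `wrapper` sketch above
prompt = torch.randn(4, 32)              # (seq, dim), no batch dimension
generated = wrapper.generate(prompt, 16) # batched internally, un-batched on return
print(generated.shape)                   # torch.Size([16, 32])
```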
@@ -292,7 +302,33 @@ class ContinuousAutoregressiveWrapper(Module):
  x,
  **kwargs
  ):
- inp, target = x[:, :-1], x[:, 1:]
+ steps = self.rollout_steps
+ one_step_autoregress = steps == 1
+
+ # get the input
+
+ inp = x[:, :-steps]
+
+ # variables
+
+ batch, seq_len, device = *inp.shape[:2], inp.device
+
+ # get target
+
+ if one_step_autoregress:
+ target = x[:, None, 1:]
+ else:
+
+ batch_arange = arange(batch, device = device)
+ batch_arange = rearrange(batch_arange, 'b -> b 1 1')
+ seq_arange = arange(seq_len, device = device)
+ steps_arange = arange(steps, device = device) + 1
+
+ target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
+
+ target = x[batch_arange, target_indices] # rollout targets
+
+ # assert inputs

  assert 'prepend_embeds' not in kwargs

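The advanced-indexing gather above builds, for every position `n`, the targets 1 through `steps` tokens ahead, so `target[b, r, n] == x[b, n + r + 1]`. A plain-torch sketch of the same index arithmetic (all sizes are illustrative assumptions):

```python
# plain-torch sketch of the rollout target gather; all sizes are illustrative
import torch

batch, total_len, dim, steps = 2, 8, 4, 2
x = torch.randn(batch, total_len, dim)

inp = x[:, :-steps]                                   # (batch, seq_len, dim)
seq_len = inp.shape[1]

batch_arange = torch.arange(batch)[:, None, None]     # 'b -> b 1 1'
steps_arange = torch.arange(steps) + 1                # offsets 1..steps
target_indices = steps_arange[:, None] + torch.arange(seq_len)  # (steps, seq_len)

# target[b, r, n] == x[b, n + r + 1], the (r + 1)-step-ahead token
target = x[batch_arange, target_indices]              # (batch, steps, seq_len, dim)
assert torch.equal(target[0, 1, 3], x[0, 5])
```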
@@ -303,29 +339,54 @@ class ContinuousAutoregressiveWrapper(Module):
  if exists(lens):
  assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
  seq_len, device = inp.shape[1], inp.device
- seq_arange = torch.arange(seq_len, device = device)
+ seq_arange = arange(seq_len, device = device)
  mask = einx.less('j, i -> i j', seq_arange, lens)
-
  kwargs['mask'] = mask

- # mask
+ # handle mask manually

- mask = kwargs.get('mask', None)
+ mask = kwargs.pop('mask', None)

- if exists(mask) and mask.shape[1] == x.shape[1]:
- mask = mask[:, :-1]
- kwargs['mask'] = mask
+ has_mask = exists(mask)
+
+ # maybe rollout
+
+ outputs = []
+ masks = []
+
+ for step_index in range(steps):
+
+ step_mask = None
+ if has_mask:
+ step_mask = mask[:, step_index:(step_index + seq_len)]
+ masks.append(step_mask)
+
+ # forward
+
+ out = self.net(inp, mask = step_mask, **kwargs)
+
+ outputs.append(out)
+
+ inp = out
+
+ # stack masks and predictions from rollouts
+
+ masks = stack(masks, dim = 1) if exists(mask) else None
+
+ pred = stack(outputs, dim = 1)
+
+ # loss

- out = self.net(inp, **kwargs)
+ loss = self.loss_fn(pred, target)

- loss = self.loss_fn(out, target)
+ # adjusting loss based on mask

- if exists(mask):
+ if has_mask:
  assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'

  if self.equal_loss_weight_batch:
- loss = masked_mean(loss, mask)
+ loss = masked_mean(loss, masks)
  else:
- loss = loss[mask]
+ loss = loss[masks]

  return loss.mean()
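The rollout itself is a simple recurrence: each step's output becomes the next step's input, the mask is shifted by one position per step, and the stacked per-step predictions line up with the stacked targets for the loss. A schematic sketch with a stand-in module (the `nn.Linear` and shapes are illustrative assumptions):

```python
# schematic sketch of the rollout recurrence; nn.Linear stands in for the
# transformer and all shapes are illustrative assumptions
import torch
from torch import nn, stack

net = nn.Linear(4, 4)            # stand-in model (feedback needs dim_out == dim_in)
inp = torch.randn(2, 6, 4)       # (batch, seq_len, dim)
steps = 2

outputs = []
for step_index in range(steps):
    out = net(inp)               # one-step-ahead prediction at every position
    outputs.append(out)
    inp = out                    # feed the prediction back in, recurrently

pred = stack(outputs, dim = 1)   # (batch, steps, seq_len, dim), aligned with targets
```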
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.3.14
+ Version: 2.3.15
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2486,4 +2486,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Assran2025VJEPA2S,
+ title = {V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
+ author = {Mahmoud Assran and Adrien Bardes and David Fan and Quentin Garrido and Russell Howes and Mojtaba Komeili and Matthew Muckley and Ammar Rizvi and Claire Roberts and Koustuv Sinha and Artem Zholus and Sergio Arnaud and Abha Gejji and Ada Martin and Francois Robert Hogan and Daniel Dugas and Piotr Bojanowski and Vasil Khalidov and Patrick Labatut and Francisco Massa and Marc Szafraniec and Kapil Krishnakumar and Yong Li and Xiaodong Ma and Sarath Chandar and Franziska Meier and Yann LeCun and Michael Rabbat and Nicolas Ballas and Fair at Meta and Mila - Qu{\'e}bec and AI Institute and Polytechnique Montr{\'e}al},
+ year = {2025},
+ url = {https://api.semanticscholar.org/CorpusID:279306055}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
- x_transformers/continuous.py,sha256=KPKi7TKqHYcDWYVhSkSB9y5iZMnhzVZxHhjJRdL7w5I,9521
+ x_transformers/continuous.py,sha256=5SZmi3Bd77aJAu50f4y1OwwruZd_3ZHptC8dtQmvvxM,11387
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
- x_transformers-2.3.14.dist-info/METADATA,sha256=Tnvnrfnr-eIlUVEH3IePLykynVikAq-t01v4pSh3yPQ,89022
- x_transformers-2.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.3.14.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.3.14.dist-info/RECORD,,
+ x_transformers-2.3.15.dist-info/METADATA,sha256=Dv6r9tZhbF-_7q5yhgmo3pto-D6NVJnYPNnSBEBt73I,89897
+ x_transformers-2.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.3.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.3.15.dist-info/RECORD,,