x-transformers 2.3.14__py3-none-any.whl → 2.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import torch
4
- from torch import nn, cat, stack
4
+ from torch import nn, cat, stack, arange
5
5
  from torch.nn import Module
6
6
  import torch.nn.functional as F
7
7
  from torch.distributions import Normal
@@ -64,7 +64,7 @@ class ContinuousTransformerWrapper(Module):
64
64
  use_abs_pos_emb = True,
65
65
  scaled_sinu_pos_emb = False,
66
66
  average_pool_embed = False,
67
- probabilistic = False
67
+ probabilistic = False,
68
68
  ):
69
69
  super().__init__()
70
70
  dim = attn_layers.dim
@@ -130,6 +130,7 @@ class ContinuousTransformerWrapper(Module):
130
130
  sum_embeds = None,
131
131
  prepend_embeds = None,
132
132
  prepend_mask = None,
133
+ seq_start_pos = None,
133
134
  **kwargs
134
135
  ):
135
136
  batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
@@ -138,14 +139,14 @@ class ContinuousTransformerWrapper(Module):
138
139
 
139
140
  if exists(lens):
140
141
  assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
141
- seq_arange = torch.arange(seq, device = device)
142
+ seq_arange = arange(seq, device = device)
142
143
 
143
144
  mask = einx.less('j, i -> i j', seq_arange, lens)
144
145
 
145
146
  # project in + positional embedding
146
147
 
147
148
  x = self.project_in(x)
148
- x = x + self.pos_emb(x, pos = pos)
149
+ x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
149
150
 
150
151
  if exists(sum_embeds):
151
152
  x = x + sum_embeds
@@ -220,7 +221,8 @@ class ContinuousAutoregressiveWrapper(Module):
220
221
  self,
221
222
  net: ContinuousTransformerWrapper,
222
223
  loss_fn: Module | None = None,
223
- equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
224
+ equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
225
+ rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
224
226
  ):
225
227
  super().__init__()
226
228
  self.net = net
@@ -234,6 +236,14 @@ class ContinuousAutoregressiveWrapper(Module):
234
236
  self.loss_fn = loss_fn
235
237
  self.equal_loss_weight_batch = equal_loss_weight_batch
236
238
 
239
+ # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
240
+ # applied successfully in vjepa2 world model, with rollout steps of 2
241
+ # rollout steps of 1 would be the same as single step autoregressive
242
+
243
+ assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
244
+ assert 1 <= rollout_steps
245
+ self.rollout_steps = rollout_steps
246
+
237
247
  @torch.no_grad()
238
248
  def generate(
239
249
  self,
@@ -247,12 +257,13 @@ class ContinuousAutoregressiveWrapper(Module):
247
257
  device = start_tokens.device
248
258
 
249
259
  was_training = self.net.training
250
- num_dims = len(start_tokens.shape)
260
+ num_dims = start_tokens.ndim
251
261
 
252
262
  assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
263
+ no_batch = num_dims == 2
253
264
 
254
- if num_dims == 2:
255
- start_tokens = start_tokens[None, :]
265
+ if no_batch:
266
+ start_tokens = rearrange(start_tokens, 'n d -> 1 n d')
256
267
 
257
268
  b, t, _, device = *start_tokens.shape, start_tokens.device
258
269
 
@@ -281,8 +292,8 @@ class ContinuousAutoregressiveWrapper(Module):
281
292
 
282
293
  out = out[:, t:]
283
294
 
284
- if num_dims == 2:
285
- out = out.squeeze(0)
295
+ if no_batch:
296
+ out = rearrange(out, '1 n d -> n d')
286
297
 
287
298
  self.net.train(was_training)
288
299
  return out
@@ -292,7 +303,37 @@ class ContinuousAutoregressiveWrapper(Module):
292
303
  x,
293
304
  **kwargs
294
305
  ):
295
- inp, target = x[:, :-1], x[:, 1:]
306
+ steps = self.rollout_steps
307
+ one_step_autoregress = steps == 1
308
+
309
+ # get the input
310
+
311
+ inp = x[:, :-steps]
312
+
313
+ # variables
314
+
315
+ batch, seq_len, device = *inp.shape[:2], inp.device
316
+
317
+ # get target
318
+
319
+ seq_start_pos = None
320
+
321
+ if one_step_autoregress:
322
+ target = x[:, None, 1:]
323
+ else:
324
+
325
+ batch_arange = arange(batch, device = device)
326
+ batch_arange = rearrange(batch_arange, 'b -> b 1 1')
327
+ seq_arange = arange(seq_len, device = device)
328
+ steps_arange = arange(steps, device = device) + 1
329
+
330
+ target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
331
+
332
+ target = x[batch_arange, target_indices] # rollout targets
333
+
334
+ seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
335
+
336
+ # assert inputs
296
337
 
297
338
  assert 'prepend_embeds' not in kwargs
298
339
 
@@ -303,29 +344,57 @@ class ContinuousAutoregressiveWrapper(Module):
303
344
  if exists(lens):
304
345
  assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
305
346
  seq_len, device = inp.shape[1], inp.device
306
- seq_arange = torch.arange(seq_len, device = device)
347
+ seq_arange = arange(seq_len, device = device)
307
348
  mask = einx.less('j, i -> i j', seq_arange, lens)
308
-
309
349
  kwargs['mask'] = mask
310
350
 
311
- # mask
351
+ # handle mask manually
312
352
 
313
- mask = kwargs.get('mask', None)
353
+ mask = kwargs.pop('mask', None)
314
354
 
315
- if exists(mask) and mask.shape[1] == x.shape[1]:
316
- mask = mask[:, :-1]
317
- kwargs['mask'] = mask
355
+ has_mask = exists(mask)
356
+
357
+ # maybe rollout
358
+
359
+ outputs = []
360
+ masks = []
361
+
362
+ for step_index in range(steps):
363
+
364
+ step_mask = None
365
+ if has_mask:
366
+ step_mask = mask[:, step_index:(step_index + seq_len)]
367
+ masks.append(step_mask)
368
+
369
+ # forward
370
+
371
+ out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
372
+
373
+ outputs.append(out)
374
+
375
+ inp = out
376
+
377
+ if not one_step_autoregress:
378
+ seq_start_pos.sub_(1)
379
+
380
+ # stack masks and predictions from rollouts
381
+
382
+ masks = stack(masks, dim = 1) if exists(mask) else None
383
+
384
+ pred = stack(outputs, dim = 1)
385
+
386
+ # loss
318
387
 
319
- out = self.net(inp, **kwargs)
388
+ loss = self.loss_fn(pred, target)
320
389
 
321
- loss = self.loss_fn(out, target)
390
+ # adjusting loss based on mask
322
391
 
323
- if exists(mask):
392
+ if has_mask:
324
393
  assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
325
394
 
326
395
  if self.equal_loss_weight_batch:
327
- loss = masked_mean(loss, mask)
396
+ loss = masked_mean(loss, masks)
328
397
  else:
329
- loss = loss[mask]
398
+ loss = loss[masks]
330
399
 
331
400
  return loss.mean()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.14
3
+ Version: 2.3.16
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2486,4 +2486,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
2486
2486
  }
2487
2487
  ```
2488
2488
 
2489
+ ```bibtex
2490
+ @inproceedings{Assran2025VJEPA2S,
2491
+ title = {V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
2492
+ author = {Mahmoud Assran and Adrien Bardes and David Fan and Quentin Garrido and Russell Howes and Mojtaba Komeili and Matthew Muckley and Ammar Rizvi and Claire Roberts and Koustuv Sinha and Artem Zholus and Sergio Arnaud and Abha Gejji and Ada Martin and Francois Robert Hogan and Daniel Dugas and Piotr Bojanowski and Vasil Khalidov and Patrick Labatut and Francisco Massa and Marc Szafraniec and Kapil Krishnakumar and Yong Li and Xiaodong Ma and Sarath Chandar and Franziska Meier and Yann LeCun and Michael Rabbat and Nicolas Ballas and Fair at Meta and Mila - Qu{\'e}bec and AI Institute and Polytechnique Montr{\'e}al},
2493
+ year = {2025},
2494
+ url = {https://api.semanticscholar.org/CorpusID:279306055}
2495
+ }
2496
+ ```
2497
+
2489
2498
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
3
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=KPKi7TKqHYcDWYVhSkSB9y5iZMnhzVZxHhjJRdL7w5I,9521
5
+ x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.14.dist-info/METADATA,sha256=Tnvnrfnr-eIlUVEH3IePLykynVikAq-t01v4pSh3yPQ,89022
15
- x_transformers-2.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.14.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.14.dist-info/RECORD,,
14
+ x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
15
+ x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.16.dist-info/RECORD,,