x-transformers 2.3.14__py3-none-any.whl → 2.3.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +83 -22
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/METADATA +10 -1
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/RECORD +5 -5
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import torch
-from torch import nn, cat, stack
+from torch import nn, cat, stack, arange
 from torch.nn import Module
 import torch.nn.functional as F
 from torch.distributions import Normal
@@ -64,7 +64,7 @@ class ContinuousTransformerWrapper(Module):
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
         average_pool_embed = False,
-        probabilistic = False
+        probabilistic = False,
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -138,7 +138,7 @@ class ContinuousTransformerWrapper(Module):
 
         if exists(lens):
             assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
-            seq_arange =
+            seq_arange = arange(seq, device = device)
 
             mask = einx.less('j, i -> i j', seq_arange, lens)
 
@@ -220,7 +220,8 @@ class ContinuousAutoregressiveWrapper(Module):
         self,
         net: ContinuousTransformerWrapper,
         loss_fn: Module | None = None,
-        equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
+        equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
+        rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
     ):
         super().__init__()
         self.net = net
@@ -234,6 +235,14 @@ class ContinuousAutoregressiveWrapper(Module):
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
 
+        # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
+        # applied successfully in vjepa2 world model, with rollout steps of 2
+        # rollout steps of 1 would be the same as single step autoregressive
+
+        assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
+        assert 1 <= rollout_steps
+        self.rollout_steps = rollout_steps
+
     @torch.no_grad()
     def generate(
         self,
@@ -247,12 +256,13 @@ class ContinuousAutoregressiveWrapper(Module):
         device = start_tokens.device
 
         was_training = self.net.training
-        num_dims =
+        num_dims = start_tokens.ndim
 
         assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
+        no_batch = num_dims == 2
 
-        if
-            start_tokens = start_tokens
+        if no_batch:
+            start_tokens = rearrange(start_tokens, 'n d -> 1 n d')
 
         b, t, _, device = *start_tokens.shape, start_tokens.device
 
@@ -281,8 +291,8 @@ class ContinuousAutoregressiveWrapper(Module):
 
         out = out[:, t:]
 
-        if
-            out = out
+        if no_batch:
+            out = rearrange(out, '1 n d -> n d')
 
         self.net.train(was_training)
         return out
@@ -292,7 +302,33 @@ class ContinuousAutoregressiveWrapper(Module):
         x,
         **kwargs
     ):
-
+        steps = self.rollout_steps
+        one_step_autoregress = steps == 1
+
+        # get the input
+
+        inp = x[:, :-steps]
+
+        # variables
+
+        batch, seq_len, device = *inp.shape[:2], inp.device
+
+        # get target
+
+        if one_step_autoregress:
+            target = x[:, None, 1:]
+        else:
+
+            batch_arange = arange(batch, device = device)
+            batch_arange = rearrange(batch_arange, 'b -> b 1 1')
+            seq_arange = arange(seq_len, device = device)
+            steps_arange = arange(steps, device = device) + 1
+
+            target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
+
+            target = x[batch_arange, target_indices] # rollout targets
+
+        # assert inputs
 
         assert 'prepend_embeds' not in kwargs
 
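A note on the target construction above: for each rollout step r (counting from 1), the gathered indices are r, r+1, ..., r+seq_len-1, so the r-th recurrent prediction is supervised against the token r positions ahead of each input position. Below is a minimal standalone sketch of the same index arithmetic, using plain broadcasting in place of einx and toy, illustrative shapes:

```python
import torch

batch, seq_total, dim, steps = 2, 6, 3, 2
x = torch.randn(batch, seq_total, dim)

inp = x[:, :-steps]                      # (batch, 4, dim) - what the model sees
seq_len = inp.shape[1]

steps_arange = torch.arange(steps) + 1   # tensor([1, 2])
seq_arange = torch.arange(seq_len)       # tensor([0, 1, 2, 3])

# same result as einx.add('r, n -> r n', steps_arange, seq_arange)
target_indices = steps_arange[:, None] + seq_arange[None, :]
# tensor([[1, 2, 3, 4],
#         [2, 3, 4, 5]])

batch_arange = torch.arange(batch)[:, None, None]

# advanced indexing broadcasts to (batch, steps, seq_len), then gathers dim
target = x[batch_arange, target_indices] # (batch, steps, 4, dim) rollout targets
```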
@@ -303,29 +339,54 @@ class ContinuousAutoregressiveWrapper(Module):
         if exists(lens):
             assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
             seq_len, device = inp.shape[1], inp.device
-            seq_arange =
+            seq_arange = arange(seq_len, device = device)
             mask = einx.less('j, i -> i j', seq_arange, lens)
-
             kwargs['mask'] = mask
 
-        # mask
+        # handle mask manually
 
-        mask = kwargs.
+        mask = kwargs.pop('mask', None)
 
-
-
-
+        has_mask = exists(mask)
+
+        # maybe rollout
+
+        outputs = []
+        masks = []
+
+        for step_index in range(steps):
+
+            step_mask = None
+            if has_mask:
+                step_mask = mask[:, step_index:(step_index + seq_len)]
+                masks.append(step_mask)
+
+            # forward
+
+            out = self.net(inp, mask = step_mask, **kwargs)
+
+            outputs.append(out)
+
+            inp = out
+
+        # stack masks and predictions from rollouts
+
+        masks = stack(masks, dim = 1) if exists(mask) else None
+
+        pred = stack(outputs, dim = 1)
+
+        # loss
 
-
+        loss = self.loss_fn(pred, target)
 
-        loss
+        # adjusting loss based on mask
 
-        if
+        if has_mask:
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
 
             if self.equal_loss_weight_batch:
-                loss = masked_mean(loss,
+                loss = masked_mean(loss, masks)
             else:
-                loss = loss[
+                loss = loss[masks]
 
         return loss.mean()
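Taken together, the new `rollout_steps` option makes the wrapper feed its own prediction back in as the next input and score each successive prediction against targets shifted further into the future, rather than training only one step ahead. A minimal usage sketch under assumed hyperparameters (the model dimensions and sequence sizes below are illustrative, not taken from this diff):

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Decoder
from x_transformers.continuous import ContinuousAutoregressiveWrapper

# a small continuous-input transformer; dim_in == dim_out so that
# predictions can be fed back in as inputs during rollout
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Decoder(
        dim = 64,
        depth = 2,
        heads = 4
    )
)

# rollout_steps = 2: the step-1 prediction is fed back in, and the resulting
# step-2 prediction is also scored, against targets two positions ahead
wrapper = ContinuousAutoregressiveWrapper(model, rollout_steps = 2)

seq = torch.randn(1, 128, 32)

loss = wrapper(seq)
loss.backward()

# generation is unchanged - sample 16 continuation vectors from an 8-vector prompt
sampled = wrapper.generate(seq[:, :8], 16)   # (1, 16, 32)
```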
{x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.14
+Version: 2.3.15
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2486,4 +2486,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{Assran2025VJEPA2S,
+    title  = {V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
+    author = {Mahmoud Assran and Adrien Bardes and David Fan and Quentin Garrido and Russell Howes and Mojtaba Komeili and Matthew Muckley and Ammar Rizvi and Claire Roberts and Koustuv Sinha and Artem Zholus and Sergio Arnaud and Abha Gejji and Ada Martin and Francois Robert Hogan and Daniel Dugas and Piotr Bojanowski and Vasil Khalidov and Patrick Labatut and Francisco Massa and Marc Szafraniec and Kapil Krishnakumar and Yong Li and Xiaodong Ma and Sarath Chandar and Franziska Meier and Yann LeCun and Michael Rabbat and Nicolas Ballas and Fair at Meta and Mila - Qu{\'e}bec and AI Institute and Polytechnique Montr{\'e}al},
+    year   = {2025},
+    url    = {https://api.semanticscholar.org/CorpusID:279306055}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=5SZmi3Bd77aJAu50f4y1OwwruZd_3ZHptC8dtQmvvxM,11387
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.15.dist-info/METADATA,sha256=Dv6r9tZhbF-_7q5yhgmo3pto-D6NVJnYPNnSBEBt73I,89897
+x_transformers-2.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.15.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.15.dist-info/RECORD,,
{x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/WHEEL
File without changes
{x_transformers-2.3.14.dist-info → x_transformers-2.3.15.dist-info}/licenses/LICENSE
File without changes