x-transformers 2.3.14__py3-none-any.whl → 2.3.16__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +92 -23
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.16.dist-info}/METADATA +10 -1
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.16.dist-info}/RECORD +5 -5
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.16.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.14.dist-info → x_transformers-2.3.16.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import torch
|
4
|
-
from torch import nn, cat, stack
|
4
|
+
from torch import nn, cat, stack, arange
|
5
5
|
from torch.nn import Module
|
6
6
|
import torch.nn.functional as F
|
7
7
|
from torch.distributions import Normal
|
@@ -64,7 +64,7 @@ class ContinuousTransformerWrapper(Module):
|
|
64
64
|
use_abs_pos_emb = True,
|
65
65
|
scaled_sinu_pos_emb = False,
|
66
66
|
average_pool_embed = False,
|
67
|
-
probabilistic = False
|
67
|
+
probabilistic = False,
|
68
68
|
):
|
69
69
|
super().__init__()
|
70
70
|
dim = attn_layers.dim
|
@@ -130,6 +130,7 @@ class ContinuousTransformerWrapper(Module):
|
|
130
130
|
sum_embeds = None,
|
131
131
|
prepend_embeds = None,
|
132
132
|
prepend_mask = None,
|
133
|
+
seq_start_pos = None,
|
133
134
|
**kwargs
|
134
135
|
):
|
135
136
|
batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
|
@@ -138,14 +139,14 @@ class ContinuousTransformerWrapper(Module):
|
|
138
139
|
|
139
140
|
if exists(lens):
|
140
141
|
assert not exists(mask), 'either `mask` or `lens` passed in, but not both'
|
141
|
-
seq_arange =
|
142
|
+
seq_arange = arange(seq, device = device)
|
142
143
|
|
143
144
|
mask = einx.less('j, i -> i j', seq_arange, lens)
|
144
145
|
|
145
146
|
# project in + positional embedding
|
146
147
|
|
147
148
|
x = self.project_in(x)
|
148
|
-
x = x + self.pos_emb(x, pos = pos)
|
149
|
+
x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)
|
149
150
|
|
150
151
|
if exists(sum_embeds):
|
151
152
|
x = x + sum_embeds
|
@@ -220,7 +221,8 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
220
221
|
self,
|
221
222
|
net: ContinuousTransformerWrapper,
|
222
223
|
loss_fn: Module | None = None,
|
223
|
-
equal_loss_weight_batch = False # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
|
224
|
+
equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
|
225
|
+
rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
|
224
226
|
):
|
225
227
|
super().__init__()
|
226
228
|
self.net = net
|
@@ -234,6 +236,14 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
234
236
|
self.loss_fn = loss_fn
|
235
237
|
self.equal_loss_weight_batch = equal_loss_weight_batch
|
236
238
|
|
239
|
+
# num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
|
240
|
+
# applied successfully in vjepa2 world model, with rollout steps of 2
|
241
|
+
# rollout steps of 1 would be the same as single step autoregressive
|
242
|
+
|
243
|
+
assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
|
244
|
+
assert 1 <= rollout_steps
|
245
|
+
self.rollout_steps = rollout_steps
|
246
|
+
|
237
247
|
@torch.no_grad()
|
238
248
|
def generate(
|
239
249
|
self,
|
@@ -247,12 +257,13 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
247
257
|
device = start_tokens.device
|
248
258
|
|
249
259
|
was_training = self.net.training
|
250
|
-
num_dims =
|
260
|
+
num_dims = start_tokens.ndim
|
251
261
|
|
252
262
|
assert num_dims >= 2, 'number of dimensions of your start tokens must be greater or equal to 2'
|
263
|
+
no_batch = num_dims == 2
|
253
264
|
|
254
|
-
if
|
255
|
-
start_tokens = start_tokens
|
265
|
+
if no_batch:
|
266
|
+
start_tokens = rearrange(start_tokens, 'n d -> 1 n d')
|
256
267
|
|
257
268
|
b, t, _, device = *start_tokens.shape, start_tokens.device
|
258
269
|
|
@@ -281,8 +292,8 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
281
292
|
|
282
293
|
out = out[:, t:]
|
283
294
|
|
284
|
-
if
|
285
|
-
out = out
|
295
|
+
if no_batch:
|
296
|
+
out = rearrange(out, '1 n d -> n d')
|
286
297
|
|
287
298
|
self.net.train(was_training)
|
288
299
|
return out
|
@@ -292,7 +303,37 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
292
303
|
x,
|
293
304
|
**kwargs
|
294
305
|
):
|
295
|
-
|
306
|
+
steps = self.rollout_steps
|
307
|
+
one_step_autoregress = steps == 1
|
308
|
+
|
309
|
+
# get the input
|
310
|
+
|
311
|
+
inp = x[:, :-steps]
|
312
|
+
|
313
|
+
# variables
|
314
|
+
|
315
|
+
batch, seq_len, device = *inp.shape[:2], inp.device
|
316
|
+
|
317
|
+
# get target
|
318
|
+
|
319
|
+
seq_start_pos = None
|
320
|
+
|
321
|
+
if one_step_autoregress:
|
322
|
+
target = x[:, None, 1:]
|
323
|
+
else:
|
324
|
+
|
325
|
+
batch_arange = arange(batch, device = device)
|
326
|
+
batch_arange = rearrange(batch_arange, 'b -> b 1 1')
|
327
|
+
seq_arange = arange(seq_len, device = device)
|
328
|
+
steps_arange = arange(steps, device = device) + 1
|
329
|
+
|
330
|
+
target_indices = einx.add('r, n -> r n', steps_arange, seq_arange)
|
331
|
+
|
332
|
+
target = x[batch_arange, target_indices] # rollout targets
|
333
|
+
|
334
|
+
seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
|
335
|
+
|
336
|
+
# assert inputs
|
296
337
|
|
297
338
|
assert 'prepend_embeds' not in kwargs
|
298
339
|
|
@@ -303,29 +344,57 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
303
344
|
if exists(lens):
|
304
345
|
assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
|
305
346
|
seq_len, device = inp.shape[1], inp.device
|
306
|
-
seq_arange =
|
347
|
+
seq_arange = arange(seq_len, device = device)
|
307
348
|
mask = einx.less('j, i -> i j', seq_arange, lens)
|
308
|
-
|
309
349
|
kwargs['mask'] = mask
|
310
350
|
|
311
|
-
# mask
|
351
|
+
# handle mask manually
|
312
352
|
|
313
|
-
mask = kwargs.
|
353
|
+
mask = kwargs.pop('mask', None)
|
314
354
|
|
315
|
-
|
316
|
-
|
317
|
-
|
355
|
+
has_mask = exists(mask)
|
356
|
+
|
357
|
+
# maybe rollout
|
358
|
+
|
359
|
+
outputs = []
|
360
|
+
masks = []
|
361
|
+
|
362
|
+
for step_index in range(steps):
|
363
|
+
|
364
|
+
step_mask = None
|
365
|
+
if has_mask:
|
366
|
+
step_mask = mask[:, step_index:(step_index + seq_len)]
|
367
|
+
masks.append(step_mask)
|
368
|
+
|
369
|
+
# forward
|
370
|
+
|
371
|
+
out = self.net(inp, mask = step_mask, seq_start_pos = seq_start_pos, **kwargs)
|
372
|
+
|
373
|
+
outputs.append(out)
|
374
|
+
|
375
|
+
inp = out
|
376
|
+
|
377
|
+
if not one_step_autoregress:
|
378
|
+
seq_start_pos.sub_(1)
|
379
|
+
|
380
|
+
# stack masks and predictions from rollouts
|
381
|
+
|
382
|
+
masks = stack(masks, dim = 1) if exists(mask) else None
|
383
|
+
|
384
|
+
pred = stack(outputs, dim = 1)
|
385
|
+
|
386
|
+
# loss
|
318
387
|
|
319
|
-
|
388
|
+
loss = self.loss_fn(pred, target)
|
320
389
|
|
321
|
-
loss
|
390
|
+
# adjusting loss based on mask
|
322
391
|
|
323
|
-
if
|
392
|
+
if has_mask:
|
324
393
|
assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
|
325
394
|
|
326
395
|
if self.equal_loss_weight_batch:
|
327
|
-
loss = masked_mean(loss,
|
396
|
+
loss = masked_mean(loss, masks)
|
328
397
|
else:
|
329
|
-
loss = loss[
|
398
|
+
loss = loss[masks]
|
330
399
|
|
331
400
|
return loss.mean()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: x-transformers
|
3
|
-
Version: 2.3.
|
3
|
+
Version: 2.3.16
|
4
4
|
Summary: X-Transformers
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-transformers/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-transformers
|
@@ -2486,4 +2486,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
|
|
2486
2486
|
}
|
2487
2487
|
```
|
2488
2488
|
|
2489
|
+
```bibtex
|
2490
|
+
@inproceedings{Assran2025VJEPA2S,
|
2491
|
+
title = {V-JEPA 2: Self-Supervised Video Models Enable Understanding, Prediction and Planning},
|
2492
|
+
author = {Mahmoud Assran and Adrien Bardes and David Fan and Quentin Garrido and Russell Howes and Mojtaba Komeili and Matthew Muckley and Ammar Rizvi and Claire Roberts and Koustuv Sinha and Artem Zholus and Sergio Arnaud and Abha Gejji and Ada Martin and Francois Robert Hogan and Daniel Dugas and Piotr Bojanowski and Vasil Khalidov and Patrick Labatut and Francisco Massa and Marc Szafraniec and Kapil Krishnakumar and Yong Li and Xiaodong Ma and Sarath Chandar and Franziska Meier and Yann LeCun and Michael Rabbat and Nicolas Ballas and Fair at Meta and Mila - Qu{\'e}bec and AI Institute and Polytechnique Montr{\'e}al},
|
2493
|
+
year = {2025},
|
2494
|
+
url = {https://api.semanticscholar.org/CorpusID:279306055}
|
2495
|
+
}
|
2496
|
+
```
|
2497
|
+
|
2489
2498
|
*solve intelligence... then use that to solve everything else.* - Demis Hassabis
|
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
|
|
2
2
|
x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=jy2wsQ3sS80Qwm_gnAmdAnzBfzLoWrGPacOTzU1Q6JM,11674
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.
|
16
|
-
x_transformers-2.3.
|
17
|
-
x_transformers-2.3.
|
14
|
+
x_transformers-2.3.16.dist-info/METADATA,sha256=-lL73g4mG5pszuaU7lPdMVGJ7ZtqBqhaejr5VvWWUiw,89897
|
15
|
+
x_transformers-2.3.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.16.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|