x-transformers 2.3.16__py3-none-any.whl → 2.3.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +11 -3
- x_transformers/continuous.py +92 -60
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.17.dist-info}/METADATA +1 -1
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.17.dist-info}/RECORD +6 -6
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.17.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.17.dist-info}/licenses/LICENSE +0 -0
@@ -36,8 +36,16 @@ def eval_decorator(fn):
|
|
36
36
|
|
37
37
|
# for variable-length prefixes
|
38
38
|
|
39
|
+
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
|
40
|
+
if pad == (0, 0):
|
41
|
+
return t
|
42
|
+
|
43
|
+
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
44
|
+
zeros = ((0, 0) * dims_from_right)
|
45
|
+
return F.pad(t, (*zeros, *pad), value = value)
|
46
|
+
|
39
47
|
def align_right(t, lens, pad_id = 0):
|
40
|
-
batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
|
48
|
+
batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
|
41
49
|
|
42
50
|
assert lens.ndim == 1 and lens.shape[0] == batch
|
43
51
|
assert lens.amax() <= seq_len
|
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
|
|
48
56
|
batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
|
49
57
|
prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
|
50
58
|
|
51
|
-
t =
|
59
|
+
t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
|
52
60
|
offset = max_pad_len - pad_lens
|
53
61
|
|
54
|
-
aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
|
62
|
+
aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
|
55
63
|
return aligned
|
56
64
|
|
57
65
|
# nucleus
|
x_transformers/continuous.py
CHANGED
@@ -9,6 +9,8 @@ from torch.distributions import Normal
|
|
9
9
|
import einx
|
10
10
|
from einops import rearrange, reduce, pack, repeat, unpack
|
11
11
|
|
12
|
+
from x_transformers.autoregressive_wrapper import align_right
|
13
|
+
|
12
14
|
from x_transformers.x_transformers import (
|
13
15
|
Attention,
|
14
16
|
AttentionLayers,
|
@@ -222,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
222
224
|
net: ContinuousTransformerWrapper,
|
223
225
|
loss_fn: Module | None = None,
|
224
226
|
equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
|
225
|
-
rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
|
226
227
|
):
|
227
228
|
super().__init__()
|
228
229
|
self.net = net
|
@@ -236,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
236
237
|
self.loss_fn = loss_fn
|
237
238
|
self.equal_loss_weight_batch = equal_loss_weight_batch
|
238
239
|
|
239
|
-
# num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
|
240
|
-
# applied successfully in vjepa2 world model, with rollout steps of 2
|
241
|
-
# rollout steps of 1 would be the same as single step autoregressive
|
242
|
-
|
243
|
-
assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
|
244
|
-
assert 1 <= rollout_steps
|
245
|
-
self.rollout_steps = rollout_steps
|
246
|
-
|
247
240
|
@torch.no_grad()
|
248
241
|
def generate(
|
249
242
|
self,
|
@@ -298,40 +291,18 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
298
291
|
self.net.train(was_training)
|
299
292
|
return out
|
300
293
|
|
301
|
-
def
|
294
|
+
def forward_rollout(
|
302
295
|
self,
|
303
296
|
x,
|
297
|
+
rollout_steps = 2,
|
304
298
|
**kwargs
|
305
299
|
):
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
# get the input
|
310
|
-
|
311
|
-
inp = x[:, :-steps]
|
300
|
+
assert rollout_steps > 1
|
301
|
+
assert not self.probabilistic, 'probabilistic not supported yet'
|
312
302
|
|
313
|
-
|
314
|
-
|
315
|
-
batch, seq_len, device = *inp.shape[:2], inp.device
|
316
|
-
|
317
|
-
# get target
|
318
|
-
|
319
|
-
seq_start_pos = None
|
320
|
-
|
321
|
-
if one_step_autoregress:
|
322
|
-
target = x[:, None, 1:]
|
323
|
-
else:
|
324
|
-
|
325
|
-
batch_arange = arange(batch, device = device)
|
326
|
-
batch_arange = rearrange(batch_arange, 'b -> b 1 1')
|
327
|
-
seq_arange = arange(seq_len, device = device)
|
328
|
-
steps_arange = arange(steps, device = device) + 1
|
303
|
+
steps = rollout_steps
|
329
304
|
|
330
|
-
|
331
|
-
|
332
|
-
target = x[batch_arange, target_indices] # rollout targets
|
333
|
-
|
334
|
-
seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
|
305
|
+
device = x.device
|
335
306
|
|
336
307
|
# assert inputs
|
337
308
|
|
@@ -348,53 +319,114 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
348
319
|
mask = einx.less('j, i -> i j', seq_arange, lens)
|
349
320
|
kwargs['mask'] = mask
|
350
321
|
|
322
|
+
if not exists(lens):
|
323
|
+
batch, seq_len = x.shape[:2]
|
324
|
+
lens = torch.full((batch,), seq_len, device = device)
|
325
|
+
|
351
326
|
# handle mask manually
|
352
327
|
|
353
328
|
mask = kwargs.pop('mask', None)
|
354
329
|
|
355
|
-
|
330
|
+
# pick a random range for each batch sample and align the sequence to the right for rollout loss
|
356
331
|
|
357
|
-
|
332
|
+
valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
|
333
|
+
valid_sample = valid_tokens_for_rollout > 0
|
334
|
+
|
335
|
+
x = x[valid_sample] # remove invalid sequence (lens less than rollout steps)
|
336
|
+
|
337
|
+
if exists(mask):
|
338
|
+
mask = mask[valid_sample]
|
339
|
+
|
340
|
+
batch = x.shape[0]
|
341
|
+
seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
|
358
342
|
|
359
|
-
|
360
|
-
|
343
|
+
batch_arange = torch.arange(batch, device = device)
|
344
|
+
batch_arange = rearrange(batch_arange, 'b -> b 1')
|
361
345
|
|
362
|
-
|
346
|
+
# crop out sequence to use
|
363
347
|
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
masks.append(step_mask)
|
348
|
+
seq_end_pos = seq_start_pos + steps
|
349
|
+
max_end_pos = seq_end_pos.amax().item()
|
350
|
+
x = x[:, :max_end_pos]
|
368
351
|
|
369
|
-
|
352
|
+
x = align_right(x, seq_end_pos)
|
370
353
|
|
371
|
-
|
354
|
+
# get the input
|
355
|
+
|
356
|
+
inp, targets = x[:, :-steps], x[:, -steps:]
|
357
|
+
|
358
|
+
# maybe rollout
|
359
|
+
|
360
|
+
cache = None
|
361
|
+
preds = []
|
372
362
|
|
373
|
-
|
363
|
+
for _ in range(steps):
|
374
364
|
|
375
|
-
|
365
|
+
out, cache = self.net(
|
366
|
+
inp,
|
367
|
+
seq_start_pos = seq_start_pos,
|
368
|
+
return_intermediates = True,
|
369
|
+
**kwargs
|
370
|
+
)
|
376
371
|
|
377
|
-
|
378
|
-
|
372
|
+
last_pred = out[:, -1:]
|
373
|
+
inp = last_pred
|
379
374
|
|
380
|
-
|
375
|
+
preds.append(last_pred)
|
381
376
|
|
382
|
-
|
377
|
+
# stack for predictions
|
383
378
|
|
384
|
-
|
379
|
+
preds = cat(preds, dim = 1)
|
385
380
|
|
386
381
|
# loss
|
387
382
|
|
388
|
-
loss = self.loss_fn(
|
383
|
+
loss = self.loss_fn(preds, targets)
|
384
|
+
|
385
|
+
return loss.mean()
|
386
|
+
|
387
|
+
def forward(
|
388
|
+
self,
|
389
|
+
x,
|
390
|
+
rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
|
391
|
+
**kwargs
|
392
|
+
):
|
393
|
+
if rollout_steps > 1:
|
394
|
+
return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
|
395
|
+
|
396
|
+
inp, target = x[:, :-1], x[:, 1:]
|
397
|
+
|
398
|
+
assert 'prepend_embeds' not in kwargs
|
399
|
+
|
400
|
+
# lens
|
401
|
+
|
402
|
+
lens = kwargs.pop('lens', None)
|
403
|
+
|
404
|
+
if exists(lens):
|
405
|
+
assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
|
406
|
+
seq_len, device = inp.shape[1], inp.device
|
407
|
+
seq_arange = torch.arange(seq_len, device = device)
|
408
|
+
mask = einx.less('j, i -> i j', seq_arange, lens)
|
409
|
+
|
410
|
+
kwargs['mask'] = mask
|
411
|
+
|
412
|
+
# mask
|
413
|
+
|
414
|
+
mask = kwargs.get('mask', None)
|
415
|
+
|
416
|
+
if exists(mask) and mask.shape[1] == x.shape[1]:
|
417
|
+
mask = mask[:, :-1]
|
418
|
+
kwargs['mask'] = mask
|
419
|
+
|
420
|
+
out = self.net(inp, **kwargs)
|
389
421
|
|
390
|
-
|
422
|
+
loss = self.loss_fn(out, target)
|
391
423
|
|
392
|
-
if
|
424
|
+
if exists(mask):
|
393
425
|
assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
|
394
426
|
|
395
427
|
if self.equal_loss_weight_batch:
|
396
|
-
loss = masked_mean(loss,
|
428
|
+
loss = masked_mean(loss, mask)
|
397
429
|
else:
|
398
|
-
loss = loss[
|
430
|
+
loss = loss[mask]
|
399
431
|
|
400
432
|
return loss.mean()
|
@@ -1,8 +1,8 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
|
2
2
|
x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
|
3
|
-
x_transformers/autoregressive_wrapper.py,sha256=
|
3
|
+
x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.
|
16
|
-
x_transformers-2.3.
|
17
|
-
x_transformers-2.3.
|
14
|
+
x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
|
15
|
+
x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.17.dist-info/RECORD,,
|
File without changes
|
File without changes
|