x-transformers 2.3.16 → 2.3.18 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +11 -3
- x_transformers/continuous.py +97 -60
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/METADATA +1 -1
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/RECORD +6 -6
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/licenses/LICENSE +0 -0
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -36,8 +36,16 @@ def eval_decorator(fn):
 
 # for variable lengthed prefixes
 
+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def align_right(t, lens, pad_id = 0):
-    batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
+    batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype
 
     assert lens.ndim == 1 and lens.shape[0] == batch
     assert lens.amax() <= seq_len
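For reference, the new `pad_at_dim` helper pads an arbitrary dimension of a tensor by translating the requested dimension into the right-to-left pad layout that `F.pad` expects. A self-contained sketch of its behavior, with the function body copied from the hunk above and the example values chosen purely for illustration:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    # no-op when no padding is requested
    if pad == (0, 0):
        return t

    # F.pad takes pad amounts ordered from the last dimension inward,
    # so skip the trailing dimensions with (0, 0) pairs
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

x = torch.ones(2, 3, 4)
padded = pad_at_dim(x, (2, 0), dim = 1)   # pad 2 positions on the left of dim 1
print(padded.shape)                       # torch.Size([2, 5, 4])
```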
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
 
-    t =
+    t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
     offset = max_pad_len - pad_lens
 
-    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
+    aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
     return aligned
 
 # nucleus
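The `align_right` changes (unpacking only `t.shape[:2]`, padding along dim 1, and the trailing `...` in the gather) let the helper right-align tensors that carry trailing feature dimensions, which is what the continuous wrapper below needs. A minimal usage sketch, assuming x-transformers 2.3.18 is installed; the shapes and lengths are illustrative:

```python
import torch
from x_transformers.autoregressive_wrapper import align_right

# batch of 2 continuous sequences with 4 feature dims; valid lengths 3 and 5
x = torch.randn(2, 5, 4)
lens = torch.tensor([3, 5])

aligned = align_right(x, lens, pad_id = 0.)
# shape stays (2, 5, 4); the shorter sequence is now right-aligned and left-padded with 0.
print(aligned.shape)
```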
x_transformers/continuous.py
CHANGED
@@ -9,6 +9,8 @@ from torch.distributions import Normal
 import einx
 from einops import rearrange, reduce, pack, repeat, unpack
 
+from x_transformers.autoregressive_wrapper import align_right
+
 from x_transformers.x_transformers import (
     Attention,
     AttentionLayers,
@@ -222,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
         net: ContinuousTransformerWrapper,
         loss_fn: Module | None = None,
         equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
-        rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
     ):
         super().__init__()
         self.net = net
@@ -236,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch
 
-        # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
-        # applied successfully in vjepa2 world model, with rollout steps of 2
-        # rollout steps of 1 would be the same as single step autoregressive
-
-        assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
-        assert 1 <= rollout_steps
-        self.rollout_steps = rollout_steps
-
     @torch.no_grad()
     def generate(
         self,
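With these two hunks, `rollout_steps` is no longer a constructor argument of `ContinuousAutoregressiveWrapper`; the wrapper is built as before and rollout is requested per call (see the new `forward` further down). A minimal construction sketch in the 2.3.18 style; the model hyperparameters are arbitrary placeholders:

```python
import torch
from x_transformers import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper, Decoder

model = ContinuousTransformerWrapper(
    dim_in = 32,          # illustrative feature dimensions
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 256, depth = 4, heads = 4)
)

# no rollout_steps here anymore; it moved to forward(...)
wrapper = ContinuousAutoregressiveWrapper(model)
```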
@@ -298,40 +291,17 @@ class ContinuousAutoregressiveWrapper(Module):
         self.net.train(was_training)
         return out
 
-    def
+    def forward_rollout(
         self,
         x,
+        rollout_steps = 2,
         **kwargs
     ):
-
-        one_step_autoregress = steps == 1
-
-        # get the input
-
-        inp = x[:, :-steps]
-
-        # variables
-
-        batch, seq_len, device = *inp.shape[:2], inp.device
-
-        # get target
-
-        seq_start_pos = None
-
-        if one_step_autoregress:
-            target = x[:, None, 1:]
-        else:
-
-            batch_arange = arange(batch, device = device)
-            batch_arange = rearrange(batch_arange, 'b -> b 1 1')
-            seq_arange = arange(seq_len, device = device)
-            steps_arange = arange(steps, device = device) + 1
+        assert rollout_steps > 1
 
-
+        steps = rollout_steps
 
-
-
-            seq_start_pos = torch.zeros(batch, device = device, dtype = torch.long)
+        device = x.device
 
         # assert inputs
 
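Conceptually, `forward_rollout` replaces the old in-`forward` rollout logic: the model's own last prediction is fed back as the next input for `rollout_steps` iterations, and the loss is taken against the next `rollout_steps` ground-truth frames. A toy sketch of that feedback loop; `predict_next` is a made-up stand-in for the `self.net(...)` call (the real code also passes `seq_start_pos`, captures intermediates, and handles the probabilistic case):

```python
import torch
import torch.nn.functional as F

def predict_next(seq):
    # hypothetical stand-in for the transformer: just echoes the last frame
    return seq[:, -1:, :]

x = torch.randn(1, 6, 8)                 # (batch, time, features)
steps = 2
inp, targets = x[:, :-steps], x[:, -steps:]

preds = []
for _ in range(steps):
    last_pred = predict_next(inp)        # predict one frame ahead
    inp = last_pred                      # feed the prediction back in, as the new code does
    preds.append(last_pred)

preds = torch.cat(preds, dim = 1)        # (1, steps, features)
loss = F.mse_loss(preds, targets)        # rollout loss against the held-out frames
```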
@@ -348,53 +318,120 @@
             mask = einx.less('j, i -> i j', seq_arange, lens)
             kwargs['mask'] = mask
 
+        if not exists(lens):
+            batch, seq_len = x.shape[:2]
+            lens = torch.full((batch,), seq_len, device = device)
+
         # handle mask manually
 
         mask = kwargs.pop('mask', None)
 
-
+        # pick a random range for each batch sample and aligh the sequence to the right for rollout loss
 
-
+        valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
+        valid_sample = valid_tokens_for_rollout > 0
+
+        x = x[valid_sample] # remove invalid sequence (lens less than rollout steps)
+
+        if exists(mask):
+            mask = mask[valid_sample]
+
+        batch = x.shape[0]
+        seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()
+
+        batch_arange = torch.arange(batch, device = device)
+        batch_arange = rearrange(batch_arange, 'b -> b 1')
 
-
-        masks = []
+        # crop out sequence to use
 
-
+        seq_end_pos = seq_start_pos + steps
+        max_end_pos = seq_end_pos.amax().item()
+        x = x[:, :max_end_pos]
 
-
-        if has_mask:
-            step_mask = mask[:, step_index:(step_index + seq_len)]
-            masks.append(step_mask)
+        x = align_right(x, seq_end_pos)
 
-
+        # get the input
+
+        inp, targets = x[:, :-steps], x[:, -steps:]
 
-
+        # maybe rollout
 
-
+        cache = None
+        preds = []
 
-
+        for _ in range(steps):
 
-
-
+            out, cache = self.net(
+                inp,
+                seq_start_pos = seq_start_pos,
+                return_intermediates = True,
+                **kwargs
+            )
 
-
+            last_pred = out[..., -1:, :]
+
+            if self.probabilistic:
+                mean, var = last_pred
+                std = var.clamp(min = 1e-5).sqrt()
+                inp = torch.normal(mean, std)
+            else:
+                inp = last_pred
 
-
+            preds.append(last_pred)
 
-
+        # stack for predictions
+
+        preds = cat(preds, dim = 1)
 
         # loss
 
-        loss = self.loss_fn(
+        loss = self.loss_fn(preds, targets)
+
+        return loss.mean()
+
+    def forward(
+        self,
+        x,
+        rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
+        **kwargs
+    ):
+        if rollout_steps > 1:
+            return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
+
+        inp, target = x[:, :-1], x[:, 1:]
+
+        assert 'prepend_embeds' not in kwargs
+
+        # lens
+
+        lens = kwargs.pop('lens', None)
+
+        if exists(lens):
+            assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
+            seq_len, device = inp.shape[1], inp.device
+            seq_arange = torch.arange(seq_len, device = device)
+            mask = einx.less('j, i -> i j', seq_arange, lens)
+
+            kwargs['mask'] = mask
+
+        # mask
+
+        mask = kwargs.get('mask', None)
+
+        if exists(mask) and mask.shape[1] == x.shape[1]:
+            mask = mask[:, :-1]
+            kwargs['mask'] = mask
+
+        out = self.net(inp, **kwargs)
 
-
+        loss = self.loss_fn(out, target)
 
-        if
+        if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
 
             if self.equal_loss_weight_batch:
-                loss = masked_mean(loss,
+                loss = masked_mean(loss, mask)
             else:
-                loss = loss[
+                loss = loss[mask]
 
         return loss.mean()
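Taken together, the new `forward` keeps the single-step objective by default and dispatches to `forward_rollout` when `rollout_steps > 1`. A training-step sketch against 2.3.18, continuing the hypothetical `wrapper` from the construction sketch above and assuming `lens` is accepted on the rollout path as the surrounding context suggests:

```python
import torch

seq = torch.randn(4, 128, 32)              # (batch, time, dim_in)
lens = torch.tensor([128, 100, 72, 128])   # optional per-sample valid lengths

# default: the usual one-step autoregressive loss
loss = wrapper(seq, lens = lens)
loss.backward()

# rollout training: feeds predictions back for 2 steps, as in the VJEPA-style setup
loss = wrapper(seq, lens = lens, rollout_steps = 2)
loss.backward()
```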
{x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=uV2hLQOckeRsybqJy-0F8RhAyMPJlkVHmA7QqUJHG4g,12433
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.18.dist-info/METADATA,sha256=RKXNlO50fifu1Nas38iZRn6IJVDkv4Cen94XYVJlWg0,89897
+x_transformers-2.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.18.dist-info/RECORD,,
{x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/WHEEL
File without changes
{x_transformers-2.3.16.dist-info → x_transformers-2.3.18.dist-info}/licenses/LICENSE
File without changes