x-transformers 2.3.15__py3-none-any.whl → 2.3.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/autoregressive_wrapper.py +11 -3
- x_transformers/continuous.py +95 -55
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/METADATA +1 -1
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/RECORD +6 -6
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/licenses/LICENSE +0 -0
x_transformers/autoregressive_wrapper.py
CHANGED
@@ -36,8 +36,16 @@ def eval_decorator(fn):

 # for variable lengthed prefixes

+def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
+    if pad == (0, 0):
+        return t
+
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
 def align_right(t, lens, pad_id = 0):
-    batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
+    batch, seq_len, device, dtype = *t.shape[:2], t.device, t.dtype

     assert lens.ndim == 1 and lens.shape[0] == batch
     assert lens.amax() <= seq_len
@@ -48,10 +56,10 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

-    t =
+    t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
     offset = max_pad_len - pad_lens

-    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
+    aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
     return aligned

 # nucleus
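The change to align_right above generalizes it from 2-D token-id tensors to tensors with trailing feature dimensions (note the switch to t.shape[:2] and the trailing ... in the gather), with left-padding delegated to the new pad_at_dim helper. Below is a minimal, self-contained sketch of how the right-alignment behaves on a (batch, seq, dim) tensor; the pad_lens / max_pad_len lines are assumptions filled in from context, since they fall in lines the second hunk does not show.

# A self-contained sketch of the updated helpers, for illustration only.
# pad_at_dim and the visible parts of align_right are copied from the hunks above;
# pad_lens / max_pad_len are assumed from context.

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
    if pad == (0, 0):
        return t

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

def align_right(t, lens, pad_id = 0):
    batch, seq_len, device = *t.shape[:2], t.device

    assert lens.ndim == 1 and lens.shape[0] == batch
    assert lens.amax() <= seq_len

    pad_lens = seq_len - lens            # assumed: how far each row falls short of seq_len
    max_pad_len = int(pad_lens.amax())   # assumed: shared left-padding budget

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    t = pad_at_dim(t, (max_pad_len, 0), value = pad_id, dim = 1)
    offset = max_pad_len - pad_lens

    # the trailing ... lets this gather work for (batch, seq) ids and (batch, seq, dim) features alike
    aligned = t[batch_arange, prompt_len_arange + offset[..., None], ...]
    return aligned

# right-align a batch of continuous sequences whose valid prefixes differ in length
x = torch.randn(2, 5, 3)                 # (batch, seq, feature)
lens = torch.tensor([3, 5])
print(align_right(x, lens, pad_id = 0.).shape)   # torch.Size([2, 5, 3])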
x_transformers/continuous.py
CHANGED
@@ -9,6 +9,8 @@ from torch.distributions import Normal
 import einx
 from einops import rearrange, reduce, pack, repeat, unpack

+from x_transformers.autoregressive_wrapper import align_right
+
 from x_transformers.x_transformers import (
     Attention,
     AttentionLayers,
@@ -130,6 +132,7 @@ class ContinuousTransformerWrapper(Module):
         sum_embeds = None,
         prepend_embeds = None,
         prepend_mask = None,
+        seq_start_pos = None,
         **kwargs
     ):
         batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
@@ -145,7 +148,7 @@
         # project in + positional embedding

         x = self.project_in(x)
-        x = x + self.pos_emb(x, pos = pos)
+        x = x + self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos)

         if exists(sum_embeds):
             x = x + sum_embeds
@@ -221,7 +224,6 @@ class ContinuousAutoregressiveWrapper(Module):
         net: ContinuousTransformerWrapper,
         loss_fn: Module | None = None,
         equal_loss_weight_batch = False, # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)
-        rollout_steps = 1 # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
     ):
         super().__init__()
         self.net = net
@@ -235,14 +237,6 @@ class ContinuousAutoregressiveWrapper(Module):
         self.loss_fn = loss_fn
         self.equal_loss_weight_batch = equal_loss_weight_batch

-        # num rollout steps - if greater than one, recurrently feedback the output and enforce loss rollout steps - 1 ahead
-        # applied successfully in vjepa2 world model, with rollout steps of 2
-        # rollout steps of 1 would be the same as single step autoregressive
-
-        assert not (rollout_steps > 1 and probabilistic), f'rollout steps greater than 1 only supported for non-probabilistic'
-        assert 1 <= rollout_steps
-        self.rollout_steps = rollout_steps
-
     @torch.no_grad()
     def generate(
         self,
@@ -297,36 +291,18 @@ class ContinuousAutoregressiveWrapper(Module):
         self.net.train(was_training)
         return out

-    def
+    def forward_rollout(
         self,
         x,
+        rollout_steps = 2,
         **kwargs
     ):
-
-
-
-        # get the input
-
-        inp = x[:, :-steps]
-
-        # variables
-
-        batch, seq_len, device = *inp.shape[:2], inp.device
-
-        # get target
-
-        if one_step_autoregress:
-            target = x[:, None, 1:]
-        else:
-
-            batch_arange = arange(batch, device = device)
-            batch_arange = rearrange(batch_arange, 'b -> b 1 1')
-            seq_arange = arange(seq_len, device = device)
-            steps_arange = arange(steps, device = device) + 1
+        assert rollout_steps > 1
+        assert not self.probabilistic, 'probabilistic not supported yet'

-
+        steps = rollout_steps

-
+        device = x.device

         # assert inputs

@@ -343,50 +319,114 @@
             mask = einx.less('j, i -> i j', seq_arange, lens)
             kwargs['mask'] = mask

+        if not exists(lens):
+            batch, seq_len = x.shape[:2]
+            lens = torch.full((batch,), seq_len, device = device)
+
         # handle mask manually

         mask = kwargs.pop('mask', None)

-
+        # pick a random range for each batch sample and aligh the sequence to the right for rollout loss

-
+        valid_tokens_for_rollout = (lens - steps).clamp(min = 0)
+        valid_sample = valid_tokens_for_rollout > 0

-
-        masks = []
+        x = x[valid_sample] # remove invalid sequence (lens less than rollout steps)

-
+        if exists(mask):
+            mask = mask[valid_sample]

-
-
-            step_mask = mask[:, step_index:(step_index + seq_len)]
-            masks.append(step_mask)
+        batch = x.shape[0]
+        seq_start_pos = (torch.rand((batch,), device = device) * valid_tokens_for_rollout).floor().long()

-
+        batch_arange = torch.arange(batch, device = device)
+        batch_arange = rearrange(batch_arange, 'b -> b 1')

-
+        # crop out sequence to use

-
+        seq_end_pos = seq_start_pos + steps
+        max_end_pos = seq_end_pos.amax().item()
+        x = x[:, :max_end_pos]

-
+        x = align_right(x, seq_end_pos)
+
+        # get the input

-
+        inp, targets = x[:, :-steps], x[:, -steps:]

-
+        # maybe rollout
+
+        cache = None
+        preds = []

-
+        for _ in range(steps):
+
+            out, cache = self.net(
+                inp,
+                seq_start_pos = seq_start_pos,
+                return_intermediates = True,
+                **kwargs
+            )
+
+            last_pred = out[:, -1:]
+            inp = last_pred
+
+            preds.append(last_pred)
+
+        # stack for predictions
+
+        preds = cat(preds, dim = 1)

         # loss

-        loss = self.loss_fn(
+        loss = self.loss_fn(preds, targets)
+
+        return loss.mean()
+
+    def forward(
+        self,
+        x,
+        rollout_steps = 1, # they used 2 rollout steps in a successful world model paper https://ai.meta.com/vjepa/
+        **kwargs
+    ):
+        if rollout_steps > 1:
+            return self.forward_rollout(x, rollout_steps = rollout_steps, **kwargs)
+
+        inp, target = x[:, :-1], x[:, 1:]
+
+        assert 'prepend_embeds' not in kwargs
+
+        # lens
+
+        lens = kwargs.pop('lens', None)
+
+        if exists(lens):
+            assert 'mask' not in kwargs, 'either `mask` or `lens` passed in, but not both'
+            seq_len, device = inp.shape[1], inp.device
+            seq_arange = torch.arange(seq_len, device = device)
+            mask = einx.less('j, i -> i j', seq_arange, lens)
+
+            kwargs['mask'] = mask
+
+        # mask
+
+        mask = kwargs.get('mask', None)
+
+        if exists(mask) and mask.shape[1] == x.shape[1]:
+            mask = mask[:, :-1]
+            kwargs['mask'] = mask
+
+        out = self.net(inp, **kwargs)

-
+        loss = self.loss_fn(out, target)

-        if
+        if exists(mask):
             assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'

             if self.equal_loss_weight_batch:
-                loss = masked_mean(loss,
+                loss = masked_mean(loss, mask)
             else:
-                loss = loss[
+                loss = loss[mask]

         return loss.mean()
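Taken together, the continuous.py changes move rollout control from the constructor to the call site: rollout_steps is no longer an __init__ argument, forward now accepts rollout_steps and routes to the new forward_rollout when it is greater than 1, and ContinuousTransformerWrapper.forward gained a seq_start_pos argument used by the right-aligned rollout windows. A rough usage sketch under those assumptions; the wrapper hyperparameters below are illustrative placeholders, not values taken from this diff.

# Usage sketch based only on the signatures visible in this diff.

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder
from x_transformers.continuous import ContinuousAutoregressiveWrapper

model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 256, depth = 4, heads = 4)
)

# rollout_steps is no longer passed to the constructor (removed in this diff)
wrapper = ContinuousAutoregressiveWrapper(model)

x = torch.randn(2, 128, 32)

# single-step autoregressive loss, same behavior as before
loss = wrapper(x)
loss.backward()

# multi-step rollout loss, now requested per call and routed to forward_rollout
rollout_loss = wrapper(x, rollout_steps = 2)
rollout_loss.backward()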
{x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/RECORD
CHANGED
@@ -1,8 +1,8 @@
 x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
 x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
+x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.17.dist-info/RECORD,,
{x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/WHEEL
File without changes
{x_transformers-2.3.15.dist-info → x_transformers-2.3.17.dist-info}/licenses/LICENSE
File without changes