x-transformers 1.40.11__py3-none-any.whl → 1.41.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/dpo.py +7 -4
- x_transformers/x_transformers.py +84 -16
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/METADATA +2 -1
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/RECORD +7 -7
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/WHEEL +1 -1
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/top_level.txt +0 -0
x_transformers/dpo.py CHANGED

@@ -5,6 +5,7 @@ from torch.nn import Module
 import torch.nn.functional as F
 from x_transformers.x_transformers import TransformerWrapper
 
+import einx
 from einops import rearrange
 
 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False
 
 def log_prob_from_model_and_seq(model, seq):
-
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)
 
 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)
 
+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
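For readers unfamiliar with einx.get_at, here is a standalone sketch (not taken from the package, toy sizes) of what the rewritten gather does: the bracketed axis in 'b n [l], b n -> b n' is indexed per position, which is the old gather-then-squeeze idiom in a single call.

# Standalone sketch, not from the package: einx.get_at on the vocab axis
# does the same work as the old gather + rearrange pair.
import torch
import einx
from einops import rearrange

b, n, vocab = 2, 5, 11                                      # toy sizes
seq = torch.randint(0, vocab, (b, n + 1))
log_prob = torch.randn(b, n, vocab).log_softmax(dim = -1)   # stand-in for model(src_seq) log-probs

src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]

# new style: index the bracketed vocab axis [l] with the shifted targets
picked = einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

# old style: gather along the last dim, then drop the singleton axis
indices = rearrange(tgt_seq, '... -> ... 1')
gathered = rearrange(log_prob.gather(-1, indices), '... 1 -> ...')

assert torch.allclose(picked, gathered)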
x_transformers/x_transformers.py CHANGED

@@ -16,8 +16,9 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
 
-
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack
 
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -100,6 +101,12 @@ def log(t, eps = 1e-20):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
+def reverse_cumsum(t, dim = -1):
+    t = t.flip(dims = (dim,))
+    t = t.cumsum(dim = dim)
+    t = t.flip(dims = (dim,))
+    return t
+
 def l2norm(t, groups = 1):
     t = rearrange(t, '... (g d) -> ... g d', g = groups)
     t = F.normalize(t, p = 2, dim = -1)
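As a quick sanity check on the new reverse_cumsum helper (a worked example, not part of the diff): each position accumulates everything at or to the right of it.

# Worked example for the flip -> cumsum -> flip trick used by reverse_cumsum
import torch

def reverse_cumsum(t, dim = -1):
    t = t.flip(dims = (dim,))
    t = t.cumsum(dim = dim)
    t = t.flip(dims = (dim,))
    return t

x = torch.tensor([1., 2., 3., 4.])
print(reverse_cumsum(x))   # tensor([10., 9., 7., 4.])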
@@ -324,7 +331,7 @@ class RelativePositionBias(Module):
         device = self.device
         q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
         k_pos = torch.arange(j, dtype = torch.long, device = device)
-        rel_pos =
+        rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
         bias = rearrange(values, 'i j h -> h i j')
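The einx.subtract rewrites throughout this file replace hand-rolled broadcasting. A standalone check (not from the package) of the 'j, i -> i j' pattern used here:

# Standalone check: einx.subtract('j, i -> i j', k_pos, q_pos)[i, j] == k_pos[j] - q_pos[i]
import torch
import einx
from einops import rearrange

q_pos = torch.arange(3)   # query positions
k_pos = torch.arange(5)   # key positions

rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)

# explicit broadcast equivalent
expected = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
assert torch.equal(rel_pos, expected)   # shape (3, 5)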
@@ -351,8 +358,10 @@ class CoPE(Module):
         self.soft_onehot = soft_onehot
         self.soft_onehot_temp = soft_onehot_temp
 
-        if soft_onehot:
-
+        if not soft_onehot:
+            return
+
+        self.register_buffer('positions', torch.arange(max_pos))
 
     def forward(self, query, attn_logits):
 
@@ -374,7 +383,7 @@ class CoPE(Module):
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
 
         if self.soft_onehot:
-            diff_pos = (
+            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
             soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
@@ -423,7 +432,7 @@ class DynamicPositionBias(Module):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices =
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)
 
         # input to continuous positions MLP
@@ -453,9 +462,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('bias', None, persistent = False)
 
     def get_bias(self, i, j, device):
-
-
-        bias = -torch.abs(
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias
 
     @staticmethod
@@ -490,6 +499,44 @@ class AlibiPositionalBias(Module):
 
         return self.bias
 
+class DataDependentAlibi(Module):
+    """ https://openreview.net/forum?id=q2Lnyegkr8 """
+
+    def __init__(
+        self,
+        dim,
+        heads
+    ):
+        super().__init__()
+
+        linear = nn.Linear(dim, heads)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        )
+
+        nn.init.constant_(linear.bias, 5.)
+
+    def forward(self, x):
+        seq = x.shape[-2]
+
+        forget_gates = self.to_forget_gates(x).log()
+        forget_gates = repeat(forget_gates, 'b h j -> b h i j', i = seq)
+
+        # causal mask out, including diagonal (so token to itself attention is never masked out)
+
+        causal_mask = torch.ones((seq, seq), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumulative sum in log space (equivalent to cumprod)
+
+        forget_gates = reverse_cumsum(forget_gates)
+
+        return forget_gates
+
 class RotaryEmbedding(Module):
     def __init__(
         self,
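A small shape and sign check of the new module, assuming x-transformers 1.41.0 is installed; the dim and heads values below are made up for illustration. The output is a per-head (i, j) bias of accumulated log forget gates: non-positive, and exactly zero on and above the diagonal.

# Hypothetical usage sketch of DataDependentAlibi (dims are illustrative only)
import torch
from x_transformers.x_transformers import DataDependentAlibi

dda = DataDependentAlibi(dim = 512, heads = 8)

x = torch.randn(2, 16, 512)          # (batch, seq, dim)
bias = dda(x)                        # (batch, heads, seq, seq)

assert bias.shape == (2, 8, 16, 16)
assert torch.all(bias <= 0)          # sums of log-sigmoid gates are non-positive
assert torch.all(bias.triu() == 0)   # nothing on or above the diagonal is decayed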
@@ -938,6 +985,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        data_dependent_alibi = False,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1041,6 +1089,19 @@ class Attention(Module):
                 soft_onehot = cope_soft_onehot_pos
             )
 
+        # data dependent alibi
+        # https://openreview.net/forum?id=q2Lnyegkr8
+
+        self.data_dependent_alibi = None
+
+        if data_dependent_alibi:
+            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
+
+            self.data_dependent_alibi = DataDependentAlibi(
+                dim,
+                heads = heads
+            )
+
         # attend class - includes core attention algorithm + talking heads
 
         self.attend = Attend(
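Since Attention kwargs reach AttentionLayers through the attn_ prefix (see the groupby_prefix_and_trim call later in this diff), the new bias should be switchable from the usual wrappers roughly as sketched below. This is an assumed usage pattern based on that prefix convention, with made-up hyperparameters, not an excerpt from the project README.

# Assumed usage sketch: enabling the forgetting-transformers style bias
# via the attn_ kwarg prefix; hyperparameters are illustrative.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True   # routed to Attention(data_dependent_alibi = True)
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)   # (1, 1024, 256)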
@@ -1236,7 +1297,7 @@ class Attention(Module):
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist =
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1251,6 +1312,11 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
+        # prepare data dependent alibi from forgetting transformers paper, if needed
+
+        if exists(self.data_dependent_alibi):
+            attn_bias = self.data_dependent_alibi(x)
+
         # if previous values passed in for residual, either invoke resformer or neutreno
 
         if exists(value_residual):
@@ -1291,7 +1357,7 @@ class Attention(Module):
 
         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out =
+            out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
         # merge heads
 
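The per-head value gate now goes through einx.multiply, which broadcasts a (b, n, h) gate over the (b, h, n, d) attention output. A standalone equivalence check (not from the package):

# Standalone check of the head-gate broadcast pattern
import torch
import einx
from einops import rearrange

b, h, n, d = 2, 4, 8, 16
head_gate = torch.randn(b, n, h)     # e.g. output of a Linear(dim, heads) on x
out = torch.randn(b, h, n, d)

gated = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)

# explicit equivalent: move heads forward and append a singleton feature dim
expected = rearrange(head_gate.sigmoid(), 'b n h -> b h n 1') * out
assert torch.allclose(gated, expected)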
@@ -1308,8 +1374,7 @@ class Attention(Module):
         out = self.to_out(out)
 
         if exists(mask):
-
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
             return out
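The final output masking likewise moves to einx.where; the empty operand before '->' is the scalar fill value. A standalone check (not from the package) that it matches the old masked_fill on an unsqueezed mask:

# Standalone check of the einx.where output-masking pattern
import torch
import einx
from einops import rearrange

b, n, d = 2, 8, 16
mask = torch.rand(b, n) > 0.5
out = torch.randn(b, n, d)

masked_new = einx.where('b n, b n d, -> b n d', mask, out, 0.)

masked_old = out.masked_fill(~rearrange(mask, 'b n -> b n 1'), 0.)
assert torch.allclose(masked_new, masked_old)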
@@ -1389,10 +1454,13 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
+        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
+        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)
+
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
 
-
-        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
+        add_value_residual |= neutreno_value_residual
 
         self.dim = dim
         self.causal = causal
@@ -1405,7 +1473,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
-        assert
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
         # relative positional bias
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: x-transformers
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.41.0
|
4
4
|
Summary: X-Transformers - Pytorch
|
5
5
|
Home-page: https://github.com/lucidrains/x-transformers
|
6
6
|
Author: Phil Wang
|
@@ -15,5 +15,6 @@ Classifier: Programming Language :: Python :: 3.6
|
|
15
15
|
Description-Content-Type: text/markdown
|
16
16
|
License-File: LICENSE
|
17
17
|
Requires-Dist: torch >=2.0
|
18
|
+
Requires-Dist: einx >=0.3.0
|
18
19
|
Requires-Dist: einops >=0.8.0
|
19
20
|
|
{x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/RECORD CHANGED

@@ -2,14 +2,14 @@ x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,79
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=
+x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=fz71zW2IQ3NQU_csHbzCwFzGHNwrdIF9rZTLhUjmM_Q,90260
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.41.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.0.dist-info/METADATA,sha256=832V1ChJ77viLmQO7d9RU7R9SV_bspLgVl9vtdRdq5Q,689
+x_transformers-1.41.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.0.dist-info/RECORD,,
{x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/LICENSE: file without changes
{x_transformers-1.40.11.dist-info → x_transformers-1.41.0.dist-info}/top_level.txt: file without changes