x-transformers 1.40.11__py3-none-any.whl → 1.41.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/dpo.py +7 -4
- x_transformers/x_transformers.py +69 -16
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/METADATA +2 -1
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/RECORD +7 -7
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/top_level.txt +0 -0
x_transformers/dpo.py
CHANGED
@@ -5,6 +5,7 @@ from torch.nn import Module
 import torch.nn.functional as F
 from x_transformers.x_transformers import TransformerWrapper
 
+import einx
 from einops import rearrange
 
 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False
 
 def log_prob_from_model_and_seq(model, seq):
-
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)
 
 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)
 
+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
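The dpo.py change above shifts the sequence so each position scores its next token, and swaps the gather-plus-rearrange lookup for `einx.get_at`. Below is a minimal, self-contained sketch (toy shapes and random tensors are assumptions for illustration; no `TransformerWrapper` is constructed) checking that the einx pattern from the diff matches the plain `gather` formulation:

```python
import torch
import einx

batch, seq_len, vocab = 2, 8, 11
seq = torch.randint(0, vocab, (batch, seq_len))

# the updated function shifts the sequence: positions 0..n-2 are inputs, 1..n-1 are targets
src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]

# stand-in for model(src_seq): (batch, seq_len - 1, vocab) logits
logits = torch.randn(batch, seq_len - 1, vocab)
log_prob = logits.log_softmax(dim = -1)

# new form from the diff: index the vocab axis ([l]) of log_prob with tgt_seq
picked = einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

# old-style equivalent: gather on the last dim, then drop the singleton axis
gathered = log_prob.gather(-1, tgt_seq[..., None]).squeeze(-1)

assert torch.allclose(picked, gathered)
```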
x_transformers/x_transformers.py
CHANGED
@@ -16,8 +16,9 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass
 
-
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack
 
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -324,7 +325,7 @@ class RelativePositionBias(Module):
         device = self.device
         q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
         k_pos = torch.arange(j, dtype = torch.long, device = device)
-        rel_pos =
+        rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
         bias = rearrange(values, 'i j h -> h i j')
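Several hunks in this release (here and in `DynamicPositionBias`, `AlibiPositionalBias`, and the `max_attend_past` mask further down) rebuild a pairwise relative-position grid with `einx.subtract`. The sketch below, with assumed toy sizes, shows the `'j, i -> i j'` pattern producing the same matrix as explicit broadcasting:

```python
import torch
import einx

i, j = 4, 6
q_pos = torch.arange(j - i, j, dtype = torch.long)
k_pos = torch.arange(j, dtype = torch.long)

# rel_pos[a, b] = k_pos[b] - q_pos[a]
rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)

# equivalent explicit broadcasting
assert torch.equal(rel_pos, k_pos[None, :] - q_pos[:, None])

print(rel_pos.shape)  # torch.Size([4, 6])
```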
@@ -351,8 +352,10 @@ class CoPE(Module):
         self.soft_onehot = soft_onehot
         self.soft_onehot_temp = soft_onehot_temp
 
-        if soft_onehot:
-
+        if not soft_onehot:
+            return
+
+        self.register_buffer('positions', torch.arange(max_pos))
 
     def forward(self, query, attn_logits):
 
@@ -374,7 +377,7 @@ class CoPE(Module):
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
 
         if self.soft_onehot:
-            diff_pos = (
+            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
             soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
@@ -423,7 +426,7 @@ class DynamicPositionBias(Module):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices =
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)
 
         # input to continuous positions MLP
@@ -453,9 +456,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('bias', None, persistent = False)
 
     def get_bias(self, i, j, device):
-
-
-        bias = -torch.abs(
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias
 
     @staticmethod
@@ -490,6 +493,35 @@ class AlibiPositionalBias(Module):
 
         return self.bias
 
+class DataDependentAlibi(Module):
+    """ https://openreview.net/forum?id=q2Lnyegkr8 """
+
+    def __init__(
+        self,
+        dim,
+        heads
+    ):
+        super().__init__()
+
+        linear = nn.Linear(dim, heads)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n h -> b h n'),
+            nn.LogSigmoid()
+        )
+
+        nn.init.constant_(linear.bias, 5.)
+
+    def forward(self, x):
+        seq = x.shape[-2]
+
+        forget_gates = self.to_forget_gates(x)
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        return forget_gates
+
 class RotaryEmbedding(Module):
     def __init__(
         self,
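The new `DataDependentAlibi` module (from the forgetting transformers paper linked in its docstring) projects each token to per-head scores, takes `LogSigmoid` to get log forget gates, cumulative-sums them over the sequence, and returns the pairwise difference as an attention bias; the bias init of 5 keeps the gates near sigmoid(5) ≈ 0.993 at the start of training, i.e. almost no forgetting. The standalone sketch below (assumed toy sizes, no dependence on the package) reproduces that computation and checks that, for `j <= i`, the bias equals the sum of log forget gates between the two positions:

```python
import torch
import torch.nn as nn
import einx

batch, seq, dim, heads = 1, 6, 16, 2
x = torch.randn(batch, seq, dim)

linear = nn.Linear(dim, heads)
nn.init.constant_(linear.bias, 5.)            # gates start near sigmoid(5) ~ 0.993

scores = linear(x).transpose(1, 2)            # (b, n, h) -> (b, h, n), like Rearrange('b n h -> b h n')
log_gates = nn.functional.logsigmoid(scores)  # log f_t for each head and position

prefix = log_gates.cumsum(dim = -1)           # F[t] = sum over s <= t of log f_s
bias = einx.subtract('b h i, b h j -> b h i j', prefix, prefix)

# for j <= i, bias[..., i, j] is the sum of log forget gates over positions j+1 .. i,
# so adding it to attention logits scales the attention weight by the product of gates
i, j, h = 4, 1, 0
expected = log_gates[0, h, j + 1 : i + 1].sum()
assert torch.allclose(bias[0, h, i, j], expected, atol = 1e-5)
```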
@@ -938,6 +970,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        data_dependent_alibi = False,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1041,6 +1074,19 @@ class Attention(Module):
             soft_onehot = cope_soft_onehot_pos
         )
 
+        # data dependent alibi
+        # https://openreview.net/forum?id=q2Lnyegkr8
+
+        self.data_dependent_alibi = None
+
+        if data_dependent_alibi:
+            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
+
+            self.data_dependent_alibi = DataDependentAlibi(
+                dim,
+                heads = heads
+            )
+
         # attend class - includes core attention algorithm + talking heads
 
         self.attend = Attend(
@@ -1236,7 +1282,7 @@ class Attention(Module):
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist =
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1251,6 +1297,11 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
+        # prepare data dependent alibi from forgetting transformers paper, if needed
+
+        if exists(self.data_dependent_alibi):
+            attn_bias = self.data_dependent_alibi(x)
+
         # if previous values passed in for residual, either invoke resformer or neutreno
 
         if exists(value_residual):
@@ -1291,7 +1342,7 @@ class Attention(Module):
 
         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out =
+            out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)
 
         # merge heads
 
@@ -1308,8 +1359,7 @@ class Attention(Module):
         out = self.to_out(out)
 
         if exists(mask):
-
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
             return out
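The two einx calls above replace manual broadcasting: `einx.multiply` applies the per-head gate without a `rearrange`, and `einx.where` masks the merged output without unsqueezing the mask by hand. A small check with assumed toy shapes:

```python
import torch
import einx
from einops import rearrange

b, h, n, d = 2, 3, 5, 4

# per-head gating: 'b n h, b h n d -> b h n d'
out = torch.randn(b, h, n, d)
head_gate = torch.randn(b, n, h)
gated = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)
assert torch.allclose(gated, out * rearrange(head_gate.sigmoid(), 'b n h -> b h n 1'))

# masking the merged output: 'b n, b n d, -> b n d' (the trailing empty operand is the scalar 0.)
merged = torch.randn(b, n, d)
mask = torch.rand(b, n) > 0.5
filled = einx.where('b n, b n d, -> b n d', mask, merged, 0.)
assert torch.allclose(filled, merged.masked_fill(~mask[..., None], 0.))
```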
@@ -1389,10 +1439,13 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
+        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
+        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)
+
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
 
-
-        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
+        add_value_residual |= neutreno_value_residual
 
         self.dim = dim
         self.causal = causal
@@ -1405,7 +1458,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
 
-        assert
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
 
         # relative positional bias
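With `data_dependent_alibi` now read from the `attn_`-prefixed kwargs, the option can be switched on from the usual entry points. A minimal usage sketch, assuming the standard `TransformerWrapper` / `Decoder` interface and the kwarg routing shown above; per the new assert it is exclusive with `alibi_pos_bias` and `rel_pos_bias`, and the Attention-level assert requires a causal stack:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True  # routed to Attention(data_dependent_alibi = True) via the 'attn_' prefix
    )
)

logits = model(torch.randint(0, 256, (1, 128)))
print(logits.shape)  # torch.Size([1, 128, 256])
```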
{x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.11
+Version: 1.41.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -15,5 +15,6 @@ Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch >=2.0
+Requires-Dist: einx >=0.3.0
 Requires-Dist: einops >=0.8.0
 
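The METADATA change adds einx as a hard dependency alongside torch and einops. A small standard-library check; the output shown in comments is what the new metadata above implies for a 1.41.x install:

```python
from importlib.metadata import requires, version

print(version('x-transformers'))   # e.g. '1.41.1'
print(requires('x-transformers'))  # expected to include 'torch >=2.0', 'einx >=0.3.0', 'einops >=0.8.0'
```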
{x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/RECORD
CHANGED
@@ -2,14 +2,14 @@ x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,79
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=
+x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=n8W19Pnhbz-JxbC7QATApWrhI_yC4oqTHGQ1NLuindY,89814
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.41.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.1.dist-info/METADATA,sha256=UIPYbEVBLrWDGuezlnyh2tFKPlM_Mdj-pYTGxse_NMI,689
+x_transformers-1.41.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.1.dist-info/RECORD,,
{x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/LICENSE
File without changes
{x_transformers-1.40.11.dist-info → x_transformers-1.41.1.dist-info}/top_level.txt
File without changes