x-transformers 1.41.0__py3-none-any.whl → 1.41.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +45 -18
- {x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/METADATA +1 -1
- {x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/RECORD +6 -6
- {x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -101,12 +101,6 @@ def log(t, eps = 1e-20):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max
 
-def reverse_cumsum(t, dim = -1):
-    t = t.flip(dims = (dim,))
-    t = t.cumsum(dim = dim)
-    t = t.flip(dims = (dim,))
-    return t
-
 def l2norm(t, groups = 1):
     t = rearrange(t, '... (g d) -> ... g d', g = groups)
     t = F.normalize(t, p = 2, dim = -1)
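Note: the standalone reverse_cumsum helper is dropped in 1.41.2; the flip → cumsum → flip sequence it wrapped is now written out inline where it is needed (see the new PerRowDataDependentAlibi.forward in the next hunk). A minimal sketch, not part of the package, showing what the removed helper computed:

# sketch only: what the removed helper did — a suffix (reverse) cumulative sum via flip -> cumsum -> flip
import torch

def reverse_cumsum(t, dim = -1):
    return t.flip(dims = (dim,)).cumsum(dim = dim).flip(dims = (dim,))

t = torch.tensor([1., 2., 3., 4.])
print(reverse_cumsum(t))                                          # tensor([10.,  9.,  7.,  4.])
print(t.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,)))   # identical: the inlined form used below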
@@ -514,26 +508,54 @@ class DataDependentAlibi(Module):
         self.to_forget_gates = nn.Sequential(
             linear,
             Rearrange('b n h -> b h n'),
-            nn.
+            nn.LogSigmoid()
         )
 
         nn.init.constant_(linear.bias, 5.)
 
     def forward(self, x):
-
+        forget_gates = self.to_forget_gates(x)
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+        return forget_gates
 
-
-
+class PerRowDataDependentAlibi(Module):
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived from queries and keys with a small head dimension """
 
-
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
 
-
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+        )
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates)
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
 
         forget_gates = forget_gates.masked_fill(causal_mask, 0.)
 
-        # reverse
+        # reverse cumsum
 
-        forget_gates =
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))
 
         return forget_gates
 
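The reworked DataDependentAlibi.forward now produces the forgetting-transformer bias directly: per-token forget gates pass through the LogSigmoid added to the Sequential, are cumulatively summed along the sequence, and the pairwise difference of those cumulative sums yields a (batch, heads, query, key) matrix whose entries are sums of log gates between positions. A minimal standalone sketch of that math — the toy shapes, the transpose in place of Rearrange, and the note on how the bias is consumed are assumptions; only the cumsum-then-subtract pattern comes from the diff:

# hedged sketch of the DataDependentAlibi gate math; shapes here are illustrative only
import torch
import torch.nn as nn
import torch.nn.functional as F

b, n, dim, heads = 2, 5, 16, 4
x = torch.randn(b, n, dim)

linear = nn.Linear(dim, heads)
nn.init.constant_(linear.bias, 5.)                    # same init as the diff: gates start close to 1 ("remember")

log_gates = F.logsigmoid(linear(x)).transpose(1, 2)   # (b, heads, n) log forget gates, each in (-inf, 0)
cum = log_gates.cumsum(dim = -1)                      # running sum of log gates along the sequence
bias = cum[..., :, None] - cum[..., None, :]          # bias[b, h, i, j] = sum of log gates over positions j+1 .. i

print(bias.shape)                                     # torch.Size([2, 4, 5, 5])

Added to the pre-softmax attention logits in the usual ALiBi fashion, the lower triangle of this bias multiplies each attention weight by the product of forget gates between key and query. The new PerRowDataDependentAlibi reaches a bias of the same shape but derives its gates per (query, key) pair from small projections, as sketched after the Attention hunks below.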
@@ -986,6 +1008,8 @@ class Attention(Module):
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
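The two new keyword arguments opt into the per-row variant and set its gate head dimension. A hedged usage sketch — the argument names come from this hunk, but treating Attention as a directly constructable layer with these keywords, and importing it from x_transformers.x_transformers, are assumptions about the surrounding, unchanged code:

# hedged usage sketch: keyword names are from the diff; everything else relies on defaults
from x_transformers.x_transformers import Attention, PerRowDataDependentAlibi

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,                             # required: the assert in the next hunk only allows autoregressive use
    data_dependent_alibi = True,
    data_dependent_alibi_per_row = True,       # new in 1.41.2: use per-row forget gates
    data_dependent_alibi_per_row_dim_head = 8  # new in 1.41.2: head dim of the small gate projections
)

assert isinstance(attn.data_dependent_alibi, PerRowDataDependentAlibi)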
@@ -1097,10 +1121,13 @@ class Attention(Module):
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
-
-
-
-
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
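The routing above picks DataDependentAlibi unless data_dependent_alibi_per_row is set, in which case PerRowDataDependentAlibi also receives the small head dimension. A standalone sketch of what that per-row module computes, mirroring the forward added in the earlier hunk — the toy shapes are assumptions; the projection → logsigmoid → causal mask → reverse cumsum flow is from the diff:

# hedged sketch of the PerRowDataDependentAlibi gate math; shapes here are illustrative only
import torch
import torch.nn.functional as F
from einops import rearrange

b, n, dim, heads, dim_head = 2, 5, 16, 4, 8

x = torch.randn(b, n, dim)
to_qk = torch.nn.Linear(dim, heads * dim_head * 2, bias = False)

q, k = rearrange(to_qk(x), 'b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)

gates = torch.einsum('... i d, ... j d -> ... i j', q, k) * dim_head ** -0.5
gates = F.logsigmoid(gates)                          # one log forget gate per (query, key) pair

# zero out the diagonal and upper triangle so only strictly earlier keys contribute decay
causal_mask = torch.ones((n, n), dtype = torch.bool).triu()
gates = gates.masked_fill(causal_mask, 0.)

# reverse cumsum over the key dimension: gates[..., i, j] becomes the sum of masked gates for keys j .. i-1
gates = gates.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

print(gates.shape)                                   # torch.Size([2, 4, 5, 5])

The final flip → cumsum → flip is the inlined form of the reverse_cumsum helper removed in the first hunk.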
{x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=5yoN96mTpw3qygZg2plVed5lTOs9p0REtt50lOlllHU,91334
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
+x_transformers-1.41.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.2.dist-info/METADATA,sha256=84WHjWdVHIxkHdHxMEKK5OEssLU80cdj4aGTJQ7SpO4,689
+x_transformers-1.41.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.2.dist-info/RECORD,,
{x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/LICENSE
File without changes
{x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/WHEEL
File without changes
{x_transformers-1.41.0.dist-info → x_transformers-1.41.2.dist-info}/top_level.txt
File without changes