x-transformers 1.41.1-py3-none-any.whl → 1.41.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +48 -6
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/METADATA +1 -1
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/RECORD +6 -6
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -514,11 +514,48 @@ class DataDependentAlibi(Module):
         nn.init.constant_(linear.bias, 5.)

     def forward(self, x):
-        seq = x.shape[-2]
-
         forget_gates = self.to_forget_gates(x)
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+        return forget_gates
+
+class PerRowDataDependentAlibi(Module):
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queris and keys with a small head dimension """
+
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+        )
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates)
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumsum
+
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))

         return forget_gates

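The sketch below is not part of the diff; it is a minimal, standalone re-derivation of what PerRowDataDependentAlibi.forward computes, run on random tensors, to make the shapes and the log-sigmoid-plus-reverse-cumsum step concrete. Names like b, h, n, d are illustrative only.

# Illustrative sketch (not the library API): mirrors the forward pass from the hunk above
# on random tensors, to show the resulting bias shape and what the reverse cumsum encodes.
import torch
import torch.nn.functional as F
from torch import einsum

b, h, n, d = 1, 8, 6, 8                  # batch, heads, sequence length, small gate head dim
q = torch.randn(b, h, n, d)              # "forget" queries
k = torch.randn(b, h, n, d)              # "forget" keys

gates = einsum('... i d, ... j d -> ... i j', q, k) * d ** -0.5
gates = F.logsigmoid(gates)              # log of a gate in (0, 1), so every entry is <= 0

# zero the upper triangle + diagonal, then accumulate right-to-left (reverse cumsum):
# for j <= i, entry (i, j) becomes the sum of log-gates at positions j <= p < i, i.e. a
# data-dependent penalty that grows the further key j lies behind query i
causal_mask = torch.ones((n, n), dtype = torch.bool).triu()
gates = gates.masked_fill(causal_mask, 0.)
bias = gates.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

print(bias.shape)                        # torch.Size([1, 8, 6, 6]), the kind of bias added to attention logits pre-softmax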
@@ -971,6 +1008,8 @@ class Attention(Module):
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1082,10 +1121,13 @@ class Attention(Module):
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'

-
-
-
-
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs)

         # attend class - includes core attention algorithm + talking heads

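A hedged usage sketch for the two new kwargs wired up above. It assumes x-transformers' usual convention of routing attn_-prefixed kwargs from the attention-layers class down into Attention; that routing is not shown in this diff, so treat the exact argument names below as an assumption.

# Hedged usage sketch, not taken from the diff: assumes the library's usual
# `attn_`-prefixed kwarg routing from Decoder into Attention.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True,               # flag that already existed in 1.41.1
        attn_data_dependent_alibi_per_row = True,       # new in 1.41.2: use PerRowDataDependentAlibi
        attn_data_dependent_alibi_per_row_dim_head = 8  # new in 1.41.2: head dim of the gate queries/keys
    )
)

ids = torch.randint(0, 20000, (1, 256))
logits = model(ids)    # (1, 256, 20000)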
{x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=5yoN96mTpw3qygZg2plVed5lTOs9p0REtt50lOlllHU,91334
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
+x_transformers-1.41.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.2.dist-info/METADATA,sha256=84WHjWdVHIxkHdHxMEKK5OEssLU80cdj4aGTJQ7SpO4,689
+x_transformers-1.41.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.2.dist-info/RECORD,,
{x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/LICENSE
File without changes
{x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/WHEEL
File without changes
{x_transformers-1.41.1.dist-info → x_transformers-1.41.2.dist-info}/top_level.txt
File without changes