x-transformers 1.41.1__py3-none-any.whl → 1.41.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- x_transformers/x_transformers.py +58 -9
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/METADATA +1 -1
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/RECORD +6 -6
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py CHANGED
@@ -499,7 +499,9 @@ class DataDependentAlibi(Module):
     def __init__(
         self,
         dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
     ):
         super().__init__()
 
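The two new constructor arguments are easiest to read together with the `nn.LogSigmoid()` head in the next hunk: `bias_init` sets the initial gate logit, and `post_log_scale` multiplies the resulting log-space gates before the cumulative sum. A minimal arithmetic sketch (plain PyTorch, not code from the package) of why the default `bias_init = 5.` starts the gates near "no forgetting":

```python
import torch
import torch.nn.functional as F

# With bias_init = 5., the gate logit starts around 5 before any data-dependent
# contribution, so the log-sigmoid gate is almost 0 in log space, i.e. essentially
# no decay/forgetting at initialization.
print(F.logsigmoid(torch.tensor(5.)))   # ~ -0.0067
print(F.logsigmoid(torch.tensor(0.)))   # ~ -0.6931, a much stronger per-step decay

# post_log_scale scales these log-space values before the cumulative sum:
# values > 1. sharpen the decay, values < 1. soften it.
```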
@@ -511,14 +513,55 @@ class DataDependentAlibi(Module):
             nn.LogSigmoid()
         )
 
-        nn.init.constant_(linear.bias,
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale
 
     def forward(self, x):
-
-
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+        return forget_gates
+
+class PerRowDataDependentAlibi(Module):
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """
+
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8,
+        post_log_scale = 1.
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
+        )
+
+        self.post_log_scale = post_log_scale
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumsum
+
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))
 
         return forget_gates
 
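For orientation, a hedged usage sketch of the class added above; it assumes `PerRowDataDependentAlibi` can be imported from `x_transformers.x_transformers` as defined in this hunk, and that its input is a `(batch, seq, dim)` tensor of token embeddings:

```python
import torch
from x_transformers.x_transformers import PerRowDataDependentAlibi  # assumed import path

bias_fn = PerRowDataDependentAlibi(dim = 512, heads = 8, dim_head = 8)

x = torch.randn(2, 16, 512)   # (batch, seq, dim) token embeddings
bias = bias_fn(x)             # (batch, heads, seq, seq) additive attention bias

assert bias.shape == (2, 8, 16, 16)

# The upper triangle and diagonal are zeroed before the reversed cumulative sum,
# so the bias is 0 on and above the diagonal and non-positive below it: it can
# only down-weight ("forget") earlier positions, never boost them.
```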
@@ -971,6 +1014,9 @@ class Attention(Module):
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1082,10 +1128,13 @@ class Attention(Module):
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
-
-
-
-
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
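A hedged sketch of how the new flags fit together from the caller's side. The `dim`, `heads`, and `causal` argument names are taken from the surrounding code in this diff, every other constructor argument is left at its default, and this is not an excerpt from the package's documentation:

```python
import torch
from x_transformers.x_transformers import Attention  # assumed import path

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,                                  # asserted above: autoregressive only
    data_dependent_alibi = True,                    # enable the feature
    data_dependent_alibi_per_row = True,            # pick PerRowDataDependentAlibi over DataDependentAlibi
    data_dependent_alibi_per_row_dim_head = 8,      # small head dim for the gate queries/keys
    data_dependent_alibi_kwargs = dict(post_log_scale = 1.)  # forwarded to whichever class was chosen
)

x = torch.randn(2, 16, 512)
out = attn(x)  # forward's exact return values are outside the scope of this diff
```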
{x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/RECORD CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=aaxMw4iJkLAh8D4W0g8EwWDfIJVPYdhpmj5T9ZFa0js,91642
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
+x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.3.dist-info/METADATA,sha256=GyMdKP9ErEnGHbBqY3MwpSGSE41lK_KQbeE5n_O2Py4,689
+x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.3.dist-info/RECORD,,
{x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/LICENSE: File without changes
{x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/WHEEL: File without changes
{x_transformers-1.41.1.dist-info → x_transformers-1.41.3.dist-info}/top_level.txt: File without changes