x-transformers 1.41.2__py3-none-any.whl → 1.41.3__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- x_transformers/x_transformers.py +15 -8
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/METADATA +1 -1
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/RECORD +6 -6
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -499,7 +499,9 @@ class DataDependentAlibi(Module):
     def __init__(
         self,
         dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
     ):
         super().__init__()
 
@@ -511,22 +513,24 @@ class DataDependentAlibi(Module):
             nn.LogSigmoid()
         )
 
-        nn.init.constant_(linear.bias, …
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale
 
     def forward(self, x):
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
-    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a …
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """
 
     def __init__(
         self,
         dim,
         heads,
-        dim_head = 8
+        dim_head = 8,
+        post_log_scale = 1.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
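For context, the two new constructor arguments surface directly in the behavior of DataDependentAlibi: bias_init seeds the forget-gate projection bias (previously a hard-coded constant), and post_log_scale multiplies the log-sigmoid gate values before the cumulative sum. A minimal sketch of exercising them standalone, assuming the class can be imported from x_transformers.x_transformers and that torch and einx are installed (values are illustrative only):

    import torch
    from x_transformers.x_transformers import DataDependentAlibi   # assumed import path; the class lives in this module

    dda = DataDependentAlibi(
        dim = 512,
        heads = 8,
        bias_init = 4.,        # new in 1.41.3: initial value for the forget-gate projection bias
        post_log_scale = 0.5   # new in 1.41.3: scales the log-sigmoid gates before the cumulative sum
    )

    x = torch.randn(2, 256, 512)   # (batch, seq, dim) token embeddings
    bias = dda(x)                  # additive attention bias; (batch, heads, seq, seq) per the 'b h i j' pattern above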
@@ -535,14 +539,16 @@ class PerRowDataDependentAlibi(Module):
 
         self.to_forget_gates = nn.Sequential(
             linear,
-            Rearrange('b n (…
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
         )
 
+        self.post_log_scale = post_log_scale
+
     def forward(self, x):
         q, k = self.to_forget_gates(x)
         forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
 
-        forget_gates = F.logsigmoid(forget_gates)
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
 
         # mask out upper triangle + diagonal
 
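PerRowDataDependentAlibi gains the same post_log_scale knob, here applied to the log-sigmoid of the low-dimensional query/key similarities. A short sketch of constructing it with the new argument, under the same import assumption as above (the rest of forward lies outside this hunk, so no output shape is asserted here):

    import torch
    from x_transformers.x_transformers import PerRowDataDependentAlibi   # assumed import path

    per_row_dda = PerRowDataDependentAlibi(
        dim = 512,
        heads = 8,
        dim_head = 8,          # small head dimension for the forget-gate queries/keys
        post_log_scale = 0.5   # new in 1.41.3: scales the log-sigmoid of the q/k similarities
    )

    x = torch.randn(2, 256, 512)    # (batch, seq, dim)
    bias = per_row_dda(x)           # per-row forget-gate bias for the attention logits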
@@ -1010,6 +1016,7 @@ class Attention(Module):
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
         data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1127,7 +1134,7 @@ class Attention(Module):
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
 
-            self.data_dependent_alibi = dda_klass(**dda_kwargs)
+            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
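The net effect for users is a passthrough: keyword arguments supplied as data_dependent_alibi_kwargs on Attention are forwarded to whichever of the two classes above gets instantiated. A hedged end-to-end sketch, assuming the usual attn_ prefix routing that x-transformers uses to forward keyword arguments from Decoder down to Attention (the values shown are illustrative, not recommended defaults):

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_data_dependent_alibi = True,
            attn_data_dependent_alibi_kwargs = dict(   # new in 1.41.3: forwarded to DataDependentAlibi
                bias_init = 4.,
                post_log_scale = 0.5
            )
        )
    )

    tokens = torch.randint(0, 20000, (1, 1024))
    logits = model(tokens)   # (1, 1024, 20000)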
{x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=aaxMw4iJkLAh8D4W0g8EwWDfIJVPYdhpmj5T9ZFa0js,91642
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.…
-x_transformers-1.41.…
-x_transformers-1.41.…
-x_transformers-1.41.…
-x_transformers-1.41.…
+x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.3.dist-info/METADATA,sha256=GyMdKP9ErEnGHbBqY3MwpSGSE41lK_KQbeE5n_O2Py4,689
+x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.3.dist-info/RECORD,,
{x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/LICENSE
File without changes
{x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/WHEEL
File without changes
{x_transformers-1.41.2.dist-info → x_transformers-1.41.3.dist-info}/top_level.txt
File without changes