x-transformers 1.42.4__py3-none-any.whl → 1.42.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +21 -4
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/METADATA +1 -1
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/RECORD +6 -6
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -512,12 +512,15 @@ class DataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
@@ -529,9 +532,21 @@ class DataDependentAlibi(Module):
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
```
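The bidirectional path is the substance of this release: with `causal = False`, the linear projection emits two sets of per-head forget gates, the first chunk producing a data-dependent decay toward the past (lower triangle of the bias) and the second a mirrored decay toward the future (upper triangle). Below is a minimal standalone sketch of that bias construction in plain PyTorch, with broadcasting standing in for `einx.subtract`; all tensor names and shapes are illustrative, not the library's.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, heads, seq = 1, 2, 6

# stand-ins for the two chunks of cumulatively summed
# log-sigmoid forget gates from the diff's forward pass
fwd = F.logsigmoid(torch.randn(batch, heads, seq)).cumsum(dim = -1)
bwd = F.logsigmoid(torch.randn(batch, heads, seq)).cumsum(dim = -1)

# pairwise differences via broadcasting, in place of einx.subtract
# bias_fwd[..., i, j] = fwd[i] - fwd[j]  -> log decay accumulated over tokens j+1..i
# bias_bwd[..., i, j] = bwd[j] - bwd[i]  -> log decay accumulated over tokens i+1..j
bias_fwd = fwd[..., :, None] - fwd[..., None, :]
bias_bwd = bwd[..., None, :] - bwd[..., :, None]

# lower triangle penalizes attending backward, upper triangle forward;
# both terms are zero on the diagonal, so the sum does not double count
bias = bias_fwd.tril() + bias_bwd.triu()

print(bias.shape)  # torch.Size([1, 2, 6, 6])
```

Every entry is a sum of log-sigmoid gates, hence non-positive, and the penalty grows with distance at a rate the network controls per token and per head.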
```diff
@@ -541,10 +556,13 @@ class PerRowDataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
```
```diff
@@ -1138,10 +1156,9 @@ class Attention(Module):
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
```
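Since `Attention` now forwards its own `causal` flag into the bias module instead of asserting on it, data-dependent ALiBi can be enabled in non-causal stacks. A hedged usage sketch, assuming the library's usual `attn_`-prefixed kwarg routing into `Attention`; hyperparameters are illustrative, and note the per-row variant still asserts `causal`:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# Encoder is non-causal, which the removed assert previously rejected
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True  # assumed attn_-prefixed routing
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)
```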
{x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/RECORD
CHANGED

```diff
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=KIR7efx59xl0BVshU1e6RO0YKgz7zeYBXITDNYWJ4mQ,92506
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.4.dist-info/LICENSE,sha256=…
-x_transformers-1.42.4.dist-info/METADATA,sha256=…
-x_transformers-1.42.4.dist-info/WHEEL,sha256=…
-x_transformers-1.42.4.dist-info/top_level.txt,sha256=…
-x_transformers-1.42.4.dist-info/RECORD,,
+x_transformers-1.42.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.5.dist-info/METADATA,sha256=cLnay5nt6F6GKdghsaqHiZQsVmJ9dS5l-IDozZIs3ec,689
+x_transformers-1.42.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.5.dist-info/RECORD,,
```
{x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/LICENSE
File without changes

{x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/WHEEL
File without changes

{x_transformers-1.42.4.dist-info → x_transformers-1.42.5.dist-info}/top_level.txt
File without changes