x-transformers 1.42.4__py3-none-any.whl → 1.42.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +29 -6
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/METADATA +1 -1
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/RECORD +6 -6
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -512,12 +512,15 @@ class DataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
@@ -529,9 +532,21 @@ class DataDependentAlibi(Module):
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
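When causal = False, the forward pass above splits the projected gates into a forward and a reversed set, then composes the final bias from a lower-triangular (past-looking) half and an upper-triangular (future-looking) half. Below is a minimal PyTorch sketch of that assembly; the tensor names and shapes are illustrative, not the library's own code.

import torch

b, h, n = 1, 2, 5

# pretend cumulative log forget gates for each direction, shape (batch, heads, seq)
fwd = torch.randn(b, h, n).cumsum(dim = -1)
rev = torch.randn(b, h, n).cumsum(dim = -1)

# pairwise differences: fwd_bias[i, j] = fwd[i] - fwd[j], rev_bias[i, j] = rev[j] - rev[i]
fwd_bias = fwd[..., :, None] - fwd[..., None, :]   # (b, h, n, n)
rev_bias = rev[..., None, :] - rev[..., :, None]   # (b, h, n, n)

# causal half from the forward gates, anti-causal half from the reversed gates
bias = fwd_bias.tril() + rev_bias.triu()

In the causal case only the first difference is needed, which matches the previous (causal-only) behaviour.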
@@ -541,10 +556,13 @@ class PerRowDataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
@@ -1037,7 +1055,12 @@ class Attention(Module):
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
-        onnxable = False
+        onnxable = False,
+        attend_sdp_kwargs: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        )
     ):
         super().__init__()
         dim_kv = default(dim_context, dim)
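The new attend_sdp_kwargs argument is forwarded to the inner Attend module (see the last hunk of this file) and mirrors PyTorch's scaled-dot-product-attention backend toggles. A hedged construction sketch; the dim/heads/causal values are illustrative, and only the attend_sdp_kwargs keys come from the diff.

from x_transformers.x_transformers import Attention

# example: prefer fused kernels and disable the math fallback
attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,
    attend_sdp_kwargs = dict(
        enable_flash = True,
        enable_math = False,
        enable_mem_efficient = True
    )
)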
@@ -1138,10 +1161,9 @@ class Attention(Module):
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
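Dropping the causal assertion and forwarding causal into dda_kwargs means data-dependent ALiBi can now be used with non-causal attention, while the per-row variant still asserts causal = True. A sketch of what this enables; argument values are illustrative.

from x_transformers.x_transformers import Attention

# bidirectional (encoder-style) attention with data dependent alibi, now permitted
attn = Attention(
    dim = 512,
    causal = False,
    data_dependent_alibi = True
)

# the per-row variant would still raise, since PerRowDataDependentAlibi asserts causal:
# Attention(dim = 512, causal = False, data_dependent_alibi = True, data_dependent_alibi_per_row = True)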
@@ -1171,7 +1193,8 @@ class Attention(Module):
             softclamp_logits = softclamp_logits,
             logit_softclamp_value = logit_softclamp_value,
             cope = cope,
-            onnxable = onnxable
+            onnxable = onnxable,
+            sdp_kwargs = attend_sdp_kwargs
         )
 
         # head scaling
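Flags of this shape are commonly used to restrict which PyTorch SDPA backends may run. The sketch below shows roughly what the three keys correspond to on the PyTorch side, using the torch.backends.cuda.sdp_kernel context manager (deprecated in newer PyTorch in favour of torch.nn.attention.sdpa_kernel); this illustrates the PyTorch API under assumed usage, not the library's internals.

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 128, 64)

# restrict scaled_dot_product_attention to the enabled backends
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = True, enable_mem_efficient = True):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)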
{x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=cPsSl1s14_c9fMdn9cZwe6Eg3aDbcRyCTsoXUJusWUg,92706
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.6.dist-info/METADATA,sha256=OANeMK9I504gC7iErAdYMTGBUEl6FOcEwm97o4OyC1k,689
+x_transformers-1.42.6.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.6.dist-info/RECORD,,
{x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/LICENSE
File without changes
{x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/WHEEL
File without changes
{x_transformers-1.42.4.dist-info → x_transformers-1.42.6.dist-info}/top_level.txt
File without changes