x-transformers 1.42.4__py3-none-any.whl → 1.42.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -512,12 +512,15 @@ class DataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         bias_init = 5.,
-        post_log_scale = 1.
+        post_log_scale = 1.,
     ):
         super().__init__()
 
-        linear = nn.Linear(dim, heads)
+        self.causal = causal
+
+        linear = nn.Linear(dim, heads * (1 if causal else 2))
 
         self.to_forget_gates = nn.Sequential(
             linear,
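With causal = False, the gate projection doubles its width so each head gets a separate forward and reversed forget-gate track. A minimal shape sketch (dimensions are illustrative, not taken from the package):

import torch
import torch.nn as nn

dim, heads = 512, 8

causal_proj = nn.Linear(dim, heads)      # causal: one gate per head -> (batch, seq, 8)
bidir_proj  = nn.Linear(dim, heads * 2)  # bidirectional: forward + reversed gates -> (batch, seq, 16)

x = torch.randn(2, 1024, dim)
assert causal_proj(x).shape == (2, 1024, heads)
assert bidir_proj(x).shape == (2, 1024, heads * 2)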
@@ -529,9 +532,21 @@ class DataDependentAlibi(Module):
         self.post_log_scale = post_log_scale
 
     def forward(self, x):
+        bidirectional = not self.causal
+
         forget_gates = self.to_forget_gates(x) * self.post_log_scale
+
         forget_gates = forget_gates.cumsum(dim = -1)
+
+        if bidirectional:
+            forget_gates, forget_gates_reversed = forget_gates.chunk(2, dim = 1)
+
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        if bidirectional:
+            forget_gates_reversed = einx.subtract('b h j, b h i -> b h i j', forget_gates_reversed, forget_gates_reversed)
+            forget_gates = forget_gates.tril() + forget_gates_reversed.triu()
+
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
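In the new bidirectional path, the forward gates fill the lower triangle of the bias (positions attending to the past) and the reversed gates fill the upper triangle (positions attending to the future); both triangles have a zero diagonal, so tril + triu does not double count. A minimal sketch of the same composition using plain tensor ops instead of einx (function and tensor names are illustrative):

import torch
import torch.nn.functional as F

def bidirectional_alibi_bias(fwd_gates, bwd_gates):
    # fwd_gates, bwd_gates: (batch, heads, seq) log-sigmoid forget gates
    f = fwd_gates.cumsum(dim = -1)
    b = bwd_gates.cumsum(dim = -1)

    past_bias   = f.unsqueeze(-1) - f.unsqueeze(-2)  # bias[i, j] = f[i] - f[j], used where j <= i
    future_bias = b.unsqueeze(-2) - b.unsqueeze(-1)  # bias[i, j] = b[j] - b[i], used where j >= i

    return past_bias.tril() + future_bias.triu()

gates = F.logsigmoid(torch.randn(2, 8, 16, 2))
bias = bidirectional_alibi_bias(gates[..., 0], gates[..., 1])
assert bias.shape == (2, 8, 16, 16)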
@@ -541,10 +556,13 @@ class PerRowDataDependentAlibi(Module):
         self,
         dim,
         heads,
+        causal = True,
         dim_head = 8,
         post_log_scale = 1.
     ):
         super().__init__()
+        assert causal, 'bidirectional not supported yet'
+
         self.scale = dim_head ** -0.5
 
         linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
@@ -1138,10 +1156,9 @@ class Attention(Module):
         self.data_dependent_alibi = None
 
         if data_dependent_alibi:
-            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
             dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
-            dda_kwargs = dict(dim = dim, heads = heads)
+            dda_kwargs = dict(dim = dim, heads = heads, causal = causal)
 
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
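With the assert dropped and causal forwarded into the bias module, data-dependent ALiBi can now be enabled on non-causal attention (the per-row variant still asserts causal, per the hunk above). A hedged usage sketch, assuming x-transformers' usual attn_* kwarg forwarding; hyperparameters are illustrative:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Encoder(                   # non-causal stack
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True     # previously causal-only
    )
)

tokens = torch.randint(0, 256, (2, 1024))
logits = model(tokens)                       # (2, 1024, 256)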
--- x_transformers-1.42.4.dist-info/METADATA
+++ x_transformers-1.42.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.4
+Version: 1.42.5
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.42.4.dist-info/RECORD
+++ x_transformers-1.42.5.dist-info/RECORD
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=o_Rm-v1XJyIYU_zDcXWxbHN6whFcK8VKRHvlqTNaQTc,92062
+x_transformers/x_transformers.py,sha256=KIR7efx59xl0BVshU1e6RO0YKgz7zeYBXITDNYWJ4mQ,92506
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.4.dist-info/METADATA,sha256=9KikrLFEmmDn92O9ne5Qd6pEuiztY21vkXg5KWiChhw,689
-x_transformers-1.42.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.4.dist-info/RECORD,,
+x_transformers-1.42.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.5.dist-info/METADATA,sha256=cLnay5nt6F6GKdghsaqHiZQsVmJ9dS5l-IDozZIs3ec,689
+x_transformers-1.42.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.5.dist-info/RECORD,,