x-transformers 1.41.1__py3-none-any.whl → 1.41.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -514,11 +514,48 @@ class DataDependentAlibi(Module):
         nn.init.constant_(linear.bias, 5.)
 
     def forward(self, x):
-        seq = x.shape[-2]
-
         forget_gates = self.to_forget_gates(x)
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+        return forget_gates
+
+class PerRowDataDependentAlibi(Module):
+    """ same as data dependent alibi from the forgetting transformer, but the forgetting gates are also derived from queries and keys with a small head dimension """
+
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+        )
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates)
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumsum
+
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))
 
         return forget_gates
 
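The new PerRowDataDependentAlibi turns per-pair log-sigmoid forget gates into an additive attention bias: after the diagonal and upper triangle are zeroed out, the flip / cumsum / flip produces, at query position i and key position j, the sum of the gates g[i, j], g[i, j+1], ..., g[i, i-1]. A minimal sketch (not part of the package; a random tensor stands in for the log-sigmoid gate logits of one head) that checks this reading of the reverse cumsum against a brute-force sum:

import torch

n = 5
g = torch.randn(n, n)  # stand-in for F.logsigmoid(q @ k.T * scale) of a single head

# mask out upper triangle + diagonal, as in the diff
causal_mask = torch.ones((n, n), dtype = torch.bool).triu()
masked = g.masked_fill(causal_mask, 0.)

# reverse cumsum over the key dimension (flip -> cumsum -> flip)
bias = masked.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

# brute-force reference: accumulated gates between key position j and query position i
ref = torch.zeros(n, n)
for i in range(n):
    for j in range(i):
        ref[i, j] = g[i, j:i].sum()

assert torch.allclose(bias, ref)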
@@ -971,6 +1008,8 @@ class Attention(Module):
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1082,10 +1121,13 @@ class Attention(Module):
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
-            self.data_dependent_alibi = DataDependentAlibi(
-                dim,
-                heads = heads
-            )
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
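Enabling the per-row variant requires both flags, since PerRowDataDependentAlibi is only constructed inside the existing data_dependent_alibi branch, and the constructor still asserts causal attention. A hedged usage sketch, not taken from the diff: it assumes the usual x-transformers pattern of routing attn_-prefixed keyword arguments from Decoder down to Attention, while the flag names themselves come from the signature change above.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True,                # must be on: the per-row flag only takes effect inside this branch
        attn_data_dependent_alibi_per_row = True,        # new in 1.41.2, selects PerRowDataDependentAlibi
        attn_data_dependent_alibi_per_row_dim_head = 8   # new in 1.41.2, head dimension of the gate q/k projection
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)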
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.1
+Version: 1.41.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=n8W19Pnhbz-JxbC7QATApWrhI_yC4oqTHGQ1NLuindY,89814
+x_transformers/x_transformers.py,sha256=5yoN96mTpw3qygZg2plVed5lTOs9p0REtt50lOlllHU,91334
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.1.dist-info/METADATA,sha256=UIPYbEVBLrWDGuezlnyh2tFKPlM_Mdj-pYTGxse_NMI,689
-x_transformers-1.41.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.1.dist-info/RECORD,,
+x_transformers-1.41.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.2.dist-info/METADATA,sha256=84WHjWdVHIxkHdHxMEKK5OEssLU80cdj4aGTJQ7SpO4,689
+x_transformers-1.41.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.2.dist-info/RECORD,,