x-transformers 1.41.1-py3-none-any.whl → 1.41.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -499,7 +499,9 @@ class DataDependentAlibi(Module):
     def __init__(
         self,
         dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
     ):
         super().__init__()
 
@@ -511,14 +513,55 @@ class DataDependentAlibi(Module):
             nn.LogSigmoid()
         )
 
-        nn.init.constant_(linear.bias, 5.)
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale
 
     def forward(self, x):
-        seq = x.shape[-2]
-
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+        return forget_gates
+
+class PerRowDataDependentAlibi(Module):
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """
+
+    def __init__(
+        self,
+        dim,
+        heads,
+        dim_head = 8,
+        post_log_scale = 1.
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+
+        linear = nn.Linear(dim, heads * dim_head * 2, bias = False)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
+        )
+
+        self.post_log_scale = post_log_scale
+
+    def forward(self, x):
+        q, k = self.to_forget_gates(x)
+        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
+
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
+
+        # mask out upper triangle + diagonal
+
+        n = x.shape[-2]
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumsum
+
+        forget_gates = forget_gates.flip(dims = (-1,))
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = forget_gates.flip(dims = (-1,))
 
         return forget_gates
 
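For orientation, a minimal sketch of how the two bias modules above could be exercised directly. This assumes x-transformers >= 1.41.3 and that both classes remain importable from x_transformers.x_transformers; the output shapes follow from the einops patterns in the diff.

import torch
from x_transformers.x_transformers import DataDependentAlibi, PerRowDataDependentAlibi

x = torch.randn(1, 16, 512)  # (batch, seq, dim) token embeddings

# per-head forgetting gates: one scalar log-gate per position, cumulatively summed
per_head = DataDependentAlibi(dim = 512, heads = 8, bias_init = 5., post_log_scale = 1.)

# per-row variant: gates come from a small q/k projection, so they vary per (i, j) pair
per_row = PerRowDataDependentAlibi(dim = 512, heads = 8, dim_head = 8, post_log_scale = 1.)

print(per_head(x).shape)  # torch.Size([1, 8, 16, 16]) - additive attention bias per head
print(per_row(x).shape)   # torch.Size([1, 8, 16, 16]) - same shape, reverse-cumsummed below the diagonal
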
@@ -971,6 +1014,9 @@ class Attention(Module):
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
         data_dependent_alibi = False,
+        data_dependent_alibi_per_row = False,
+        data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1082,10 +1128,13 @@ class Attention(Module):
         if data_dependent_alibi:
             assert causal, 'data dependent alibi only works for autoregressive for now until further research'
 
-            self.data_dependent_alibi = DataDependentAlibi(
-                dim,
-                heads = heads
-            )
+            dda_klass = DataDependentAlibi if not data_dependent_alibi_per_row else PerRowDataDependentAlibi
+            dda_kwargs = dict(dim = dim, heads = heads)
+
+            if data_dependent_alibi_per_row:
+                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
+
+            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
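At the call-site level, the new options are reached the same way as the existing data_dependent_alibi flag. A hedged usage sketch, assuming the usual x-transformers convention of routing Attention kwargs through Decoder with an attn_ prefix; only the flag names themselves are taken from the diff above.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True,               # still required; the assert above demands causal attention
        attn_data_dependent_alibi_per_row = True,       # new in 1.41.3: selects PerRowDataDependentAlibi
        attn_data_dependent_alibi_per_row_dim_head = 8  # head dim of the small forgetting-gate q/k projection
        # attn_data_dependent_alibi_kwargs = dict(post_log_scale = 1.)  # extra kwargs forwarded to the bias module
    )
)

logits = model(torch.randint(0, 256, (1, 1024)))  # (1, 1024, 256)
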
--- x_transformers-1.41.1.dist-info/METADATA
+++ x_transformers-1.41.3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.1
+Version: 1.41.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.41.1.dist-info/RECORD
+++ x_transformers-1.41.3.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=n8W19Pnhbz-JxbC7QATApWrhI_yC4oqTHGQ1NLuindY,89814
+x_transformers/x_transformers.py,sha256=aaxMw4iJkLAh8D4W0g8EwWDfIJVPYdhpmj5T9ZFa0js,91642
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.1.dist-info/METADATA,sha256=UIPYbEVBLrWDGuezlnyh2tFKPlM_Mdj-pYTGxse_NMI,689
-x_transformers-1.41.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.1.dist-info/RECORD,,
+x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.3.dist-info/METADATA,sha256=GyMdKP9ErEnGHbBqY3MwpSGSE41lK_KQbeE5n_O2Py4,689
+x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.3.dist-info/RECORD,,