x-transformers 1.41.2-py3-none-any.whl → 1.41.3-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the registry.
@@ -499,7 +499,9 @@ class DataDependentAlibi(Module):
     def __init__(
         self,
         dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
     ):
         super().__init__()
 
@@ -511,22 +513,24 @@ class DataDependentAlibi(Module):
             nn.LogSigmoid()
         )
 
-        nn.init.constant_(linear.bias, 5.)
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale
 
     def forward(self, x):
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
-    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queris and keys with a small head dimension """
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """
 
     def __init__(
         self,
         dim,
         heads,
-        dim_head = 8
+        dim_head = 8,
+        post_log_scale = 1.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -535,14 +539,16 @@ class PerRowDataDependentAlibi(Module):
 
         self.to_forget_gates = nn.Sequential(
             linear,
-            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
         )
 
+        self.post_log_scale = post_log_scale
+
     def forward(self, x):
         q, k = self.to_forget_gates(x)
         forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
 
-        forget_gates = F.logsigmoid(forget_gates)
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
 
         # mask out upper triangle + diagonal
 
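The hunks above add two tuning knobs to the data-dependent ALiBi (forgetting transformer) bias modules: bias_init replaces the previously hard-coded 5. used to initialize the forget-gate projection bias, and post_log_scale multiplies the log-sigmoid forget gates before they are accumulated. Below is a minimal sketch of exercising the new arguments directly; the tensor sizes are illustrative and the import path simply points at the module being diffed.

    import torch
    from x_transformers.x_transformers import DataDependentAlibi, PerRowDataDependentAlibi

    # forget-gate bias, one gate per head and position
    dda = DataDependentAlibi(
        dim = 512,
        heads = 8,
        bias_init = 5.,       # formerly the hard-coded constant
        post_log_scale = 1.   # new multiplier on the log-sigmoid gates
    )

    x = torch.randn(2, 128, 512)   # (batch, seq, dim)
    bias = dda(x)                  # (batch, heads, seq, seq) additive attention bias

    # the per-row variant gains the same post_log_scale argument
    per_row = PerRowDataDependentAlibi(dim = 512, heads = 8, dim_head = 8, post_log_scale = 0.5)
    bias_per_row = per_row(x)
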
@@ -1010,6 +1016,7 @@ class Attention(Module):
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
         data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1127,7 +1134,7 @@ class Attention(Module):
         if data_dependent_alibi_per_row:
             dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
 
-        self.data_dependent_alibi = dda_klass(**dda_kwargs)
+        self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
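In the Attention constructor, the new data_dependent_alibi_kwargs dict is forwarded verbatim into whichever bias module gets built (DataDependentAlibi or PerRowDataDependentAlibi), making bias_init and post_log_scale reachable from the high-level API. Below is a hedged sketch of wiring it through TransformerWrapper, assuming the library's usual attn_ prefix routing of keyword arguments into the attention layers; the hyperparameter values are illustrative only.

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_data_dependent_alibi = True,
            # forwarded as **data_dependent_alibi_kwargs in the hunk above
            attn_data_dependent_alibi_kwargs = dict(
                bias_init = 4.,
                post_log_scale = 0.5
            )
        )
    )

    tokens = torch.randint(0, 256, (1, 1024))
    logits = model(tokens)   # (1, 1024, 256)
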
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.2
+Version: 1.41.3
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=5yoN96mTpw3qygZg2plVed5lTOs9p0REtt50lOlllHU,91334
+x_transformers/x_transformers.py,sha256=aaxMw4iJkLAh8D4W0g8EwWDfIJVPYdhpmj5T9ZFa0js,91642
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.2.dist-info/METADATA,sha256=84WHjWdVHIxkHdHxMEKK5OEssLU80cdj4aGTJQ7SpO4,689
-x_transformers-1.41.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.2.dist-info/RECORD,,
+x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.3.dist-info/METADATA,sha256=GyMdKP9ErEnGHbBqY3MwpSGSE41lK_KQbeE5n_O2Py4,689
+x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.3.dist-info/RECORD,,