x-transformers 1.41.2__py3-none-any.whl → 1.41.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/attend.py CHANGED
@@ -480,6 +480,7 @@ class Attend(Module):
             sim = sim + self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
+            print(attn_bias.shape)
            sim = sim + attn_bias
 
         if self.softclamp_logits:
x_transformers/x_transformers.py CHANGED
@@ -455,11 +455,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
-    def get_bias(self, i, j, device):
-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
-        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
-        return bias
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
     @staticmethod
     def _get_slopes(heads):
@@ -474,9 +472,21 @@ class AlibiPositionalBias(Module):
         closest_power_of_2 = 2 ** math.floor(math.log2(heads))
         return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
 
-    @property
-    def device(self):
-        return next(self.buffers()).device
+    def forward_custom_pos(
+        self,
+        pos_i: Tensor,
+        pos_j: Tensor | None = None
+    ):
+        h, device = self.total_heads, self.device
+
+        pos_j = default(pos_j, pos_i)
+        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+
+        bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
+
+        return bias
 
     def forward(self, i, j):
         h, device = self.total_heads, self.device
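
The new `forward_custom_pos` builds the same ALiBi-style bias as `forward`, but from caller-supplied position tensors, so non-contiguous or otherwise custom positions can be used. A minimal usage sketch follows; it is not part of the diff and assumes `AlibiPositionalBias` keeps its existing `(heads, total_heads)` constructor and import path.

```python
# Sketch only: exercising the new forward_custom_pos with explicit positions.
# Assumes AlibiPositionalBias is importable from x_transformers.x_transformers
# and takes (heads, total_heads) as in prior releases.
import torch
from x_transformers.x_transformers import AlibiPositionalBias

alibi = AlibiPositionalBias(heads = 4, total_heads = 8)

pos = torch.tensor([0, 2, 5, 9])       # custom, non-contiguous positions

bias = alibi.forward_custom_pos(pos)   # pos_j defaults to pos_i
# bias should be broadcastable over the attention logits, padded up to
# total_heads along the head dimension (dim = -3)
print(bias.shape)
```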
@@ -484,13 +494,15 @@ class AlibiPositionalBias(Module):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]
 
-        bias = self.get_bias(i, j, device)
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
+
         bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
 
-        num_heads_unalibied = h - bias.shape[0]
-        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
         self.register_buffer('bias', bias, persistent = False)
-
         return self.bias
 
 class DataDependentAlibi(Module):
@@ -499,7 +511,9 @@ class DataDependentAlibi(Module):
     def __init__(
         self,
         dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
     ):
         super().__init__()
 
@@ -511,22 +525,24 @@ class DataDependentAlibi(Module):
             nn.LogSigmoid()
         )
 
-        nn.init.constant_(linear.bias, 5.)
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale
 
     def forward(self, x):
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
         forget_gates = forget_gates.cumsum(dim = -1)
         forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
         return forget_gates
 
 class PerRowDataDependentAlibi(Module):
-    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queris and keys with a small head dimension """
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """
 
     def __init__(
         self,
         dim,
         heads,
-        dim_head = 8
+        dim_head = 8,
+        post_log_scale = 1.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
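
For context, `post_log_scale` simply rescales the log-sigmoid forget gates before the cumulative sum, so values below 1. soften the data-dependent decay and values above 1. sharpen it. A standalone illustration (not part of the diff, plain PyTorch only):

```python
# Standalone illustration of the post_log_scale factor: it multiplies the
# log-sigmoid output, so smaller values give a milder (less negative) decay.
import torch
import torch.nn.functional as F

gate_logits = torch.tensor([2.0, 0.0, -2.0])

for post_log_scale in (0.5, 1.0, 2.0):
    scaled = F.logsigmoid(gate_logits) * post_log_scale
    print(post_log_scale, scaled.tolist())
```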
@@ -535,14 +551,16 @@ class PerRowDataDependentAlibi(Module):
 
         self.to_forget_gates = nn.Sequential(
             linear,
-            Rearrange('b n (kv h d) -> kv b h n d', kv = 2, d = dim_head)
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
         )
 
+        self.post_log_scale = post_log_scale
+
     def forward(self, x):
         q, k = self.to_forget_gates(x)
         forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
 
-        forget_gates = F.logsigmoid(forget_gates)
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale
 
         # mask out upper triangle + diagonal
 
@@ -1010,6 +1028,7 @@ class Attention(Module):
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
         data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1127,7 +1146,7 @@ class Attention(Module):
             if data_dependent_alibi_per_row:
                 dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)
 
-            self.data_dependent_alibi = dda_klass(**dda_kwargs)
+            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)
 
         # attend class - includes core attention algorithm + talking heads
 
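The new `data_dependent_alibi_kwargs` dict is forwarded verbatim to whichever forget-gate class gets constructed, which is how the `bias_init` and `post_log_scale` options above become reachable from the attention layer. A hedged configuration sketch, assuming the library's usual `attn_`-prefixed kwarg routing into `Attention`:

```python
# Sketch only: routing the new data_dependent_alibi_kwargs through a decoder.
# The attn_ prefix forwarding and these exact option values are assumptions
# based on existing x-transformers conventions, not spelled out in this diff.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_data_dependent_alibi = True,
        attn_data_dependent_alibi_kwargs = dict(
            bias_init = 5.,
            post_log_scale = 0.5
        )
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)   # expected shape: (1, 128, 256)
```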
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.2
+Version: 1.41.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
+x_transformers/attend.py,sha256=rByHtOfuCO0br69rOB7oFsHoHrsAefZErE2FKM86q7k,17319
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=5yoN96mTpw3qygZg2plVed5lTOs9p0REtt50lOlllHU,91334
+x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.2.dist-info/METADATA,sha256=84WHjWdVHIxkHdHxMEKK5OEssLU80cdj4aGTJQ7SpO4,689
-x_transformers-1.41.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.2.dist-info/RECORD,,
+x_transformers-1.41.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.4.dist-info/METADATA,sha256=BWpwY2VILr3LAvjqQjDBm28jpS3RZWvmLrghAa7HeuU,689
+x_transformers-1.41.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.4.dist-info/RECORD,,