x-transformers 1.41.3__py3-none-any.whl → 1.41.4__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
x_transformers/attend.py CHANGED
@@ -480,6 +480,7 @@ class Attend(Module):
             sim = sim + self.pre_softmax_talking_heads(sim)
 
         if exists(attn_bias):
+            print(attn_bias.shape)
             sim = sim + attn_bias
 
         if self.softclamp_logits:
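The change above adds a print of the bias tensor's shape just before it is added to the attention logits. For context, attn_bias is the per-head positional bias (for instance, from the AlibiPositionalBias module changed below) broadcast-added onto the pre-softmax similarity matrix. A minimal sketch of the shapes involved, assuming hypothetical sizes of 8 heads and 1024 tokens:

import torch
from x_transformers.x_transformers import AlibiPositionalBias

heads, seq_len = 8, 1024                        # hypothetical sizes for illustration

alibi = AlibiPositionalBias(heads = heads, total_heads = heads)
attn_bias = alibi(seq_len, seq_len)             # (total_heads, seq_len, seq_len)

sim = torch.randn(1, heads, seq_len, seq_len)   # pre-softmax attention logits
sim = sim + attn_bias                           # broadcasts over the batch dimension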
x_transformers/x_transformers.py CHANGED
@@ -455,11 +455,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
-    def get_bias(self, i, j, device):
-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
-        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
-        return bias
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
     @staticmethod
     def _get_slopes(heads):
@@ -474,9 +472,21 @@ class AlibiPositionalBias(Module):
         closest_power_of_2 = 2 ** math.floor(math.log2(heads))
         return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
 
-    @property
-    def device(self):
-        return next(self.buffers()).device
+    def forward_custom_pos(
+        self,
+        pos_i: Tensor,
+        pos_j: Tensor | None = None
+    ):
+        h, device = self.total_heads, self.device
+
+        pos_j = default(pos_j, pos_i)
+        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+
+        bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
+
+        return bias
 
     def forward(self, i, j):
         h, device = self.total_heads, self.device
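The new forward_custom_pos builds the same ALiBi distance penalty from explicit position tensors instead of contiguous ranges, so callers can supply non-contiguous or per-batch custom positions. A minimal usage sketch, assuming the constructor takes heads and total_heads as in this class (the positions below are hypothetical):

import torch
from x_transformers.x_transformers import AlibiPositionalBias

alibi = AlibiPositionalBias(heads = 8, total_heads = 8)

pos = torch.tensor([[0, 1, 4, 9]])      # (batch, seq) explicit, non-contiguous positions

bias = alibi.forward_custom_pos(pos)    # (batch, total_heads, seq, seq)
print(bias.shape)                       # torch.Size([1, 8, 4, 4])

When pos_j is omitted it defaults to pos_i, giving the usual square self-attention bias; any heads beyond the alibied count are zero-padded along the head dimension by pad_at_dim.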
@@ -484,13 +494,15 @@ class AlibiPositionalBias(Module):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]
 
-        bias = self.get_bias(i, j, device)
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
+
         bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
 
-        num_heads_unalibied = h - bias.shape[0]
-        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
         self.register_buffer('bias', bias, persistent = False)
-
         return self.bias
 
 class DataDependentAlibi(Module):
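Note the caching in forward: the first call materializes the bias and stores it in a non-persistent buffer, and later calls with i and j no larger than the cached sizes are served by slicing that buffer rather than recomputing. A small sketch of the behavior, again assuming a hypothetical 8-head setup:

import torch
from x_transformers.x_transformers import AlibiPositionalBias

alibi = AlibiPositionalBias(heads = 8, total_heads = 8)

bias_full = alibi(512, 512)     # computed once, cached as a non-persistent buffer
bias_part = alibi(128, 256)     # cache hit: returns self.bias[..., -128:, -256:]

assert torch.equal(bias_part, bias_full[..., -128:, -256:])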
x_transformers-1.41.3.dist-info/METADATA → x_transformers-1.41.4.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.41.3
+Version: 1.41.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.41.3.dist-info/RECORD → x_transformers-1.41.4.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
+x_transformers/attend.py,sha256=rByHtOfuCO0br69rOB7oFsHoHrsAefZErE2FKM86q7k,17319
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=aaxMw4iJkLAh8D4W0g8EwWDfIJVPYdhpmj5T9ZFa0js,91642
+x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.3.dist-info/METADATA,sha256=GyMdKP9ErEnGHbBqY3MwpSGSE41lK_KQbeE5n_O2Py4,689
-x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.3.dist-info/RECORD,,
+x_transformers-1.41.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.4.dist-info/METADATA,sha256=BWpwY2VILr3LAvjqQjDBm28jpS3RZWvmLrghAa7HeuU,689
+x_transformers-1.41.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.4.dist-info/RECORD,,