x-transformers 1.41.3__py3-none-any.whl → 1.41.4__py3-none-any.whl
- x_transformers/attend.py +1 -0
- x_transformers/x_transformers.py +24 -12
- {x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/METADATA +1 -1
- {x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/RECORD +7 -7
- {x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
x_transformers/x_transformers.py
CHANGED
@@ -455,11 +455,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
-    def get_bias(self, i, j, device):
-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
-        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
-        return bias
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
     @staticmethod
     def _get_slopes(heads):
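The relocated `device` property reads the device off the module's first registered buffer, which always exists because `__init__` registers `slopes`. A minimal standalone sketch of the same pattern (the `Demo` module and its buffer shape are illustrative, not from the library):

import torch
from torch.nn import Module

class Demo(Module):
    def __init__(self):
        super().__init__()
        # mirrors AlibiPositionalBias registering 'slopes' as a non-persistent buffer
        self.register_buffer('slopes', torch.ones(8, 1, 1), persistent = False)

    @property
    def device(self):
        # the first buffer's device tracks .to() / .cuda() moves of the module
        return next(self.buffers()).device

print(Demo().device)  # cpu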
@@ -474,9 +472,21 @@ class AlibiPositionalBias(Module):
         closest_power_of_2 = 2 ** math.floor(math.log2(heads))
         return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
 
-    @property
-    def device(self):
-        return next(self.buffers()).device
+    def forward_custom_pos(
+        self,
+        pos_i: Tensor,
+        pos_j: Tensor | None = None
+    ):
+        h, device = self.total_heads, self.device
+
+        pos_j = default(pos_j, pos_i)
+        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+
+        bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
+
+        return bias
 
     def forward(self, i, j):
         h, device = self.total_heads, self.device
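The new `forward_custom_pos` generalizes the ALiBi bias to caller-supplied positions, which may be non-contiguous and may carry batch dimensions. A sketch of what it computes, with plain broadcasting standing in for `einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i)`; the head count and position values here are made up for illustration:

import math
import torch

def slopes_power_of_2(heads):
    # standard ALiBi slope schedule for a power-of-two head count
    start = 2 ** (-(2 ** -(math.log2(heads) - 3)))
    return torch.tensor([start ** (i + 1) for i in range(heads)])

heads = 8
pos = torch.tensor([0., 2., 3., 7.])   # custom, non-contiguous positions

rel = -(pos[None, :] - pos[:, None]).abs()   # (i, j): -|pos_j - pos_i|
bias = rel[None, :, :] * slopes_power_of_2(heads)[:, None, None]   # (h, i, j)

print(bias.shape)  # torch.Size([8, 4, 4])

When `total_heads` exceeds the configured `heads`, the trailing `pad_at_dim(..., dim = -3)` call zero-pads the head axis so the remaining heads receive no positional bias.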
@@ -484,13 +494,15 @@ class AlibiPositionalBias(Module):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]
 
-        bias = self.get_bias(i, j, device)
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
+
         bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
 
-        num_heads_unalibied = h - bias.shape[0]
-        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
         self.register_buffer('bias', bias, persistent = False)
-
         return self.bias
 
 class DataDependentAlibi(Module):
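Worth noting in this hunk: the head padding switched from `dim = 0` to `dim = -3`. For the unbatched `(h, i, j)` bias built in `forward` the two index the same axis, but `forward_custom_pos` can emit a `(batch, h, i, j)` bias, and only a right-relative index keeps pointing at heads. A sketch with a stand-in padding helper (`pad_heads` is written for this note, not the library's `pad_at_dim`):

import torch
import torch.nn.functional as F

def pad_heads(bias, extra_heads):
    # F.pad pads trailing dims first: (j_l, j_r, i_l, i_r, h_l, h_r),
    # so this zero-pads the third-from-last (head) axis on the right
    return F.pad(bias, (0, 0, 0, 0, 0, extra_heads))

unbatched = torch.randn(4, 10, 10)     # (h, i, j)
batched = torch.randn(2, 4, 10, 10)    # (b, h, i, j)

print(pad_heads(unbatched, 4).shape)   # torch.Size([8, 10, 10])
print(pad_heads(batched, 4).shape)     # torch.Size([2, 8, 10, 10])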
{x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=…
+x_transformers/attend.py,sha256=rByHtOfuCO0br69rOB7oFsHoHrsAefZErE2FKM86q7k,17319
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.41.3.dist-info/METADATA,sha256=…
-x_transformers-1.41.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.41.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.41.3.dist-info/RECORD,,
+x_transformers-1.41.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.4.dist-info/METADATA,sha256=BWpwY2VILr3LAvjqQjDBm28jpS3RZWvmLrghAa7HeuU,689
+x_transformers-1.41.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.4.dist-info/RECORD,,
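For readers verifying the RECORD entries above: the hash column is the urlsafe base64 of the file's SHA-256 digest with trailing '=' padding stripped, per PEP 376 as adopted by the wheel format. A small sketch (the path assumes you run it inside an unpacked 1.41.4 wheel):

import base64
import hashlib

def record_hash(path):
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# record_hash('x_transformers/attend.py') should match the entry above:
# 'sha256=rByHtOfuCO0br69rOB7oFsHoHrsAefZErE2FKM86q7k'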
{x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/LICENSE
File without changes
{x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/WHEEL
File without changes
{x_transformers-1.41.3.dist-info → x_transformers-1.41.4.dist-info}/top_level.txt
File without changes