x-transformers 1.41.2__py3-none-any.whl → 1.41.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- x_transformers/attend.py +1 -0
- x_transformers/x_transformers.py +39 -20
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/METADATA +1 -1
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/RECORD +7 -7
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
x_transformers/x_transformers.py
CHANGED
@@ -455,11 +455,9 @@ class AlibiPositionalBias(Module):
        self.register_buffer('slopes', slopes, persistent = False)
        self.register_buffer('bias', None, persistent = False)

-
-
-
-        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
-        return bias
+    @property
+    def device(self):
+        return next(self.buffers()).device

    @staticmethod
    def _get_slopes(heads):
@@ -474,9 +472,21 @@ class AlibiPositionalBias(Module):
        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

-
-
-
+    def forward_custom_pos(
+        self,
+        pos_i: Tensor,
+        pos_j: Tensor | None = None
+    ):
+        h, device = self.total_heads, self.device
+
+        pos_j = default(pos_j, pos_i)
+        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+
+        bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)
+
+        return bias

    def forward(self, i, j):
        h, device = self.total_heads, self.device
@@ -484,13 +494,15 @@ class AlibiPositionalBias(Module):
        if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
            return self.bias[..., -i:, -j:]

-
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()
+
        bias = bias * self.slopes
+        num_heads_unalibied = h - bias.shape[-3]
+        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = -3)

-        num_heads_unalibied = h - bias.shape[0]
-        bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
        self.register_buffer('bias', bias, persistent = False)
-
        return self.bias

class DataDependentAlibi(Module):
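The net effect of the AlibiPositionalBias hunks above is a new forward_custom_pos method that builds the ALiBi bias from explicit position tensors rather than from sequence lengths. Below is a minimal usage sketch, not part of the diff; the constructor arguments (heads, total_heads) and the resulting shape are assumptions based on how the class is used elsewhere in the library.

import torch
from x_transformers.x_transformers import AlibiPositionalBias

# hypothetical instantiation - the head counts are illustrative
alibi = AlibiPositionalBias(heads = 8, total_heads = 8)

# custom absolute positions, e.g. for packed or non-contiguous sequences
pos = torch.arange(6) + 10

bias = alibi.forward_custom_pos(pos)   # -|pos_j - pos_i|, scaled per head by the ALiBi slopes
print(bias.shape)                      # expected (8, 6, 6) under these assumptions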
@@ -499,7 +511,9 @@ class DataDependentAlibi(Module):
    def __init__(
        self,
        dim,
-        heads
+        heads,
+        bias_init = 5.,
+        post_log_scale = 1.
    ):
        super().__init__()

@@ -511,22 +525,24 @@ class DataDependentAlibi(Module):
            nn.LogSigmoid()
        )

-        nn.init.constant_(linear.bias,
+        nn.init.constant_(linear.bias, bias_init)
+        self.post_log_scale = post_log_scale

    def forward(self, x):
-        forget_gates = self.to_forget_gates(x)
+        forget_gates = self.to_forget_gates(x) * self.post_log_scale
        forget_gates = forget_gates.cumsum(dim = -1)
        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
        return forget_gates

class PerRowDataDependentAlibi(Module):
-    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a
+    """ same as data dependent alibi from forgetting transformer, but the forgetting gates are also derived by a queries and keys with a small head dimension """

    def __init__(
        self,
        dim,
        heads,
-        dim_head = 8
+        dim_head = 8,
+        post_log_scale = 1.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
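For reference, the forget-gate bias that DataDependentAlibi.forward computes, including the new post_log_scale factor, can be reproduced with plain tensor ops. The sketch below mirrors the lines in the hunk above; the shapes and values are illustrative, and the random tensor stands in for the linear projection of x.

import torch
import torch.nn.functional as F

b, h, n = 1, 2, 5
post_log_scale = 1.

raw_gates = torch.randn(b, h, n)                       # stand-in for the per-head linear projection of x
log_gates = F.logsigmoid(raw_gates) * post_log_scale   # to_forget_gates output, scaled as in the diff
cum = log_gates.cumsum(dim = -1)                       # cumulative log decay per head
bias = cum[..., :, None] - cum[..., None, :]           # einx.subtract('b h i, b h j -> b h i j', ...)
print(bias.shape)                                      # torch.Size([1, 2, 5, 5])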
@@ -535,14 +551,16 @@ class PerRowDataDependentAlibi(Module):

        self.to_forget_gates = nn.Sequential(
            linear,
-            Rearrange('b n (
+            Rearrange('b n (qk h d) -> qk b h n d', qk = 2, d = dim_head)
        )

+        self.post_log_scale = post_log_scale
+
    def forward(self, x):
        q, k = self.to_forget_gates(x)
        forget_gates = einsum('... i d, ... j d -> ... i j', q, k) * self.scale

-        forget_gates = F.logsigmoid(forget_gates)
+        forget_gates = F.logsigmoid(forget_gates) * self.post_log_scale

        # mask out upper triangle + diagonal

@@ -1010,6 +1028,7 @@ class Attention(Module):
        data_dependent_alibi = False,
        data_dependent_alibi_per_row = False,
        data_dependent_alibi_per_row_dim_head = 8,
+        data_dependent_alibi_kwargs: dict = dict(),
        use_cope = False,
        cope_max_pos = 16,
        cope_soft_onehot_pos = False,
@@ -1127,7 +1146,7 @@ class Attention(Module):
            if data_dependent_alibi_per_row:
                dda_kwargs.update(dim_head = data_dependent_alibi_per_row_dim_head)

-            self.data_dependent_alibi = dda_klass(**dda_kwargs)
+            self.data_dependent_alibi = dda_klass(**dda_kwargs, **data_dependent_alibi_kwargs)

        # attend class - includes core attention algorithm + talking heads

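The last two hunks add a data_dependent_alibi_kwargs dict on Attention and forward it into whichever DataDependentAlibi variant is constructed. A hedged usage sketch follows, assuming the library's usual attn_ prefix routing from the attention-layer classes down to Attention; the model sizes and kwarg values are illustrative.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        attn_data_dependent_alibi = True,
        # forwarded as **data_dependent_alibi_kwargs into the bias module (values illustrative)
        attn_data_dependent_alibi_kwargs = dict(bias_init = 5., post_log_scale = 1.)
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)   # (1, 128, 256)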
{x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=rByHtOfuCO0br69rOB7oFsHoHrsAefZErE2FKM86q7k,17319
x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=UhIbFPXjdQsbFBDHVGmV81LGHTD5qwbusDc5kl3F2A4,91987
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
-x_transformers-1.41.
+x_transformers-1.41.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.4.dist-info/METADATA,sha256=BWpwY2VILr3LAvjqQjDBm28jpS3RZWvmLrghAa7HeuU,689
+x_transformers-1.41.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.4.dist-info/RECORD,,
{x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/LICENSE
File without changes
{x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/WHEEL
File without changes
{x_transformers-1.41.2.dist-info → x_transformers-1.41.4.dist-info}/top_level.txt
File without changes