x-transformers 1.30.6__py3-none-any.whl → 1.30.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +9 -0
- x_transformers/x_transformers.py +77 -0
- {x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/METADATA +1 -1
- {x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/RECORD +7 -7
- {x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Tuple

 import torch
+from torch.nn import Module
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F

@@ -82,6 +83,7 @@ class Attend(nn.Module):
         flash = False,
         logit_softclamp_value = None,
         add_zero_kv = False,
+        cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -127,6 +129,10 @@ class Attend(nn.Module):

         self.logit_softclamp_value = logit_softclamp_value

+        # contextual positional encoding
+
+        self.cope = cope
+
         # flash attention

         self.flash = flash
@@ -318,6 +324,9 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)

+        if exists(self.cope):
+            sim = sim + self.cope(q, sim)
+
         pre_softmax_attn = sim.clone()

         if exists(self.logit_softclamp_value):
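In short, `Attend` now accepts an optional `cope` module and, when it is set, adds its output to the pre-softmax attention logits right after causal masking, via `sim = sim + self.cope(q, sim)`. Below is a minimal, self-contained sketch of that hook in plain PyTorch; the helper `attend_with_cope` is illustrative only and not part of the package API.

```python
import torch

def attend_with_cope(q, k, v, cope = None):
    # q, k, v: (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # causal mask on the pre-softmax logits
    i, j = sim.shape[-2:]
    causal_mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
    sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    # the new hook: the cope module sees the queries and the logits
    # and returns a positional bias that is added onto the logits
    if cope is not None:
        sim = sim + cope(q, sim)

    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 32, 64)
out = attend_with_cope(q, k, v)   # works with or without a cope module
```

This mirrors the order of operations in the patched `Attend.forward`: mask, add the contextual positional bias, then softmax (with optional logit softclamping in between).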
x_transformers/x_transformers.py
CHANGED
@@ -304,6 +304,57 @@ class RelativePositionBias(Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale

+class CoPE(Module):
+    """
+    Appendix B of https://arxiv.org/abs/2405.18719
+    """
+    def __init__ (
+        self,
+        dim,
+        max_pos,
+        soft_onehot = True,
+        reverse = True
+    ):
+        super () . __init__ ()
+        self.reverse = reverse
+        self.max_pos = max_pos
+        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
+
+        self.soft_onehot = soft_onehot
+
+        if soft_onehot:
+            self.register_buffer('positions', torch.arange(max_pos))
+
+    def forward(self, query, attn_logits, temp = 5e-2):
+        # compute positions
+
+        gates = attn_logits.sigmoid()
+
+        if self.reverse:
+            pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
+        else:
+            pos = gates.cumsum(dim = -1)
+
+        pos = pos.clamp(max = self.max_pos - 1)
+
+        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
+
+        if self.soft_onehot:
+            diff_pos = (pos[..., None] - self.positions).abs()
+            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
+        else:
+            # interpolate from integer positions
+            pos_ceil = pos.ceil().long()
+            pos_floor = pos.floor().long()
+            logits_ceil = logits_int.gather(-1, pos_ceil)
+            logits_floor = logits_int.gather(-1, pos_floor)
+
+            w = pos - pos_floor
+            cope_pos_emb = logits_ceil * w + logits_floor * (1 - w)
+
+        return cope_pos_emb
+
 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
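The added `CoPE` class (Contextual Position Encoding, Appendix B of the linked paper) passes the attention logits through a sigmoid to get gate values, cumulatively sums them into fractional positions, clamps them at `max_pos - 1`, and reads out a learned position embedding either through a soft one-hot over integer positions (`soft_onehot = True`, the default) or by interpolating between the floor and ceiling positions. A quick shape check, assuming the class is importable from `x_transformers.x_transformers` as of this release:

```python
import torch
from x_transformers.x_transformers import CoPE   # assumes the new class is reachable here

dim_head, heads, seq_len, max_pos = 64, 8, 32, 16

cope = CoPE(dim_head, max_pos)                    # soft_onehot = True, reverse = True by default

q   = torch.randn(1, heads, seq_len, dim_head)    # queries
sim = torch.randn(1, heads, seq_len, seq_len)     # pre-softmax attention logits

bias = cope(q, sim)
print(bias.shape)                                 # torch.Size([1, 8, 32, 32]), added onto the logits
```

Note that `pos_emb` is initialized to zeros, so the returned bias starts at zero and only becomes non-trivial as the embedding is trained.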
@@ -722,6 +773,8 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        use_cope = False,
+        cope_max_pos = 16,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -753,13 +806,16 @@ class Attention(Module):
         self.to_k = nn.Linear(dim_kv, k_dim, bias = False)

         # shared key / values, for further memory savings during inference
+
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None

         # relations projection from tp-attention
+
         self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None

         # add GLU gating for aggregated values, from alphafold2
+
         self.to_v_gate = None
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
@@ -768,6 +824,7 @@ class Attention(Module):
             nn.init.constant_(self.to_v_gate.bias, 10)

         # add per head gating of the output values, from 'Attend to nothing' paper
+
         self.to_v_head_gate = None
         if gate_value_heads:
             self.to_v_head_gate = nn.Linear(dim, heads)
@@ -775,11 +832,13 @@ class Attention(Module):
             nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention
+
         self.qk_norm = qk_norm
         self.qk_norm_groups = qk_norm_groups
         self.qk_norm_scale = qk_norm_scale

         # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
+
         self.qk_norm_dim_scale = qk_norm_dim_scale

         self.qk_norm_q_scale = self.qk_norm_k_scale = 1
@@ -790,6 +849,17 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

+        # contextual positional encoding
+        # https://arxiv.org/html/2405.18719v2
+
+        cope = None
+
+        if use_cope:
+            assert causal, 'CoPE was designed for causal attention'
+            assert not flash, 'CoPE is not flash attention compatible'
+
+            cope = CoPE(dim_head, cope_max_pos)
+
         # attend class - includes core attention algorithm + talking heads

         self.attend = Attend(
@@ -803,31 +873,38 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             logit_softclamp_value = logit_softclamp_value,
+            cope = cope,
             onnxable = onnxable
         )

         # head scaling
+
         self.head_scale = head_scale
         if head_scale:
             self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

         # explicit topk sparse attention
+
         self.sparse_topk = sparse_topk

         # add memory key / values
+
         self.num_mem_kv = num_mem_kv
         if num_mem_kv > 0:
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

         # attention on attention
+
         self.attn_on_attn = on_attn
         self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)

         # whether to rotate positions into values, for absolute positions in addition to relative
+
         self.rotary_embed_values = rotary_embed_values

         # init output projection 0
+
         if zero_init_output:
             init_zero_(self.to_out)

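`Attention` exposes the feature through two new constructor arguments, `use_cope` and `cope_max_pos`; enabling it requires causal attention, is incompatible with flash attention, and internally builds a `CoPE(dim_head, cope_max_pos)` that is handed down to `Attend`. A sketch of end-to-end usage, assuming the usual x-transformers convention that `attn_`-prefixed keyword arguments on the attention layers are routed into `Attention`:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        attn_use_cope = True,       # assumed attn_ prefix routing into Attention
        attn_cope_max_pos = 16      # number of learned contextual positions
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)                   # (1, 128, 256)
```

Since `Decoder` is causal and flash attention is off by default, both assertions in the constructor are satisfied without further configuration.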
{x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=5cof7yvOAfFviLh-luafmhtTJDemCPoy9rHHYjWxLu4,68338
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.8.dist-info/METADATA,sha256=2L0SfGhrbLMjpRKLwTp1_YH1Amu3g_j1nEuWIuGNqrQ,661
+x_transformers-1.30.8.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.8.dist-info/RECORD,,
{x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/LICENSE
File without changes
{x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/WHEEL
File without changes
{x_transformers-1.30.6.dist-info → x_transformers-1.30.8.dist-info}/top_level.txt
File without changes