x-transformers 1.30.4__py3-none-any.whl → 1.30.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +10 -0
- x_transformers/x_transformers.py +54 -0
- {x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/METADATA +1 -1
- {x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/RECORD +7 -7
- {x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -4,6 +4,7 @@ from functools import partial
 from typing import Tuple

 import torch
+from torch.nn import Module
 from torch import nn, einsum, Tensor
 import torch.nn.functional as F

@@ -22,6 +23,7 @@ class Intermediates:
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
+    layer_type: str | None = None

     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
@@ -81,6 +83,7 @@ class Attend(nn.Module):
         flash = False,
         logit_softclamp_value = None,
         add_zero_kv = False,
+        cope = None,
         onnxable = False,
         sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -126,6 +129,10 @@ class Attend(nn.Module):

         self.logit_softclamp_value = logit_softclamp_value

+        # contextual positional encoding
+
+        self.cope = cope
+
         # flash attention

         self.flash = flash
@@ -317,6 +324,9 @@ class Attend(nn.Module):
             causal_mask = self.create_causal_mask(i, j, device = device)
             sim = sim.masked_fill(causal_mask, mask_value)

+        if exists(self.cope):
+            sim = sim + self.cope(q, sim)
+
         pre_softmax_attn = sim.clone()

         if exists(self.logit_softclamp_value):
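Taken together, the attend.py change gives Attend an optional cope module: it is stored at construction and, when present, cope(q, sim) is added to the pre-softmax logits right after causal masking. A minimal sketch of that hook contract, assuming logits shaped (batch, heads, query_len, key_len) and queries shaped (batch, heads, query_len, dim_head); ZeroPositionBias is a hypothetical stand-in for anything that could be passed as cope:

import torch
from torch.nn import Module

class ZeroPositionBias(Module):
    # hypothetical stand-in for a cope module: given the queries and the
    # pre-softmax logits, return a bias that gets added onto the logits
    def forward(self, q, sim):
        return torch.zeros_like(sim)

q   = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, dim_head) -- assumed example shapes
sim = torch.randn(1, 8, 16, 16)   # pre-softmax attention logits

sim = sim + ZeroPositionBias()(q, sim)   # mirrors the new line: sim = sim + self.cope(q, sim)

Any module with that (q, sim) -> bias signature could be handed to Attend via the new cope argument; the library's own implementation is the CoPE module added below.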
x_transformers/x_transformers.py
CHANGED
@@ -304,6 +304,33 @@ class RelativePositionBias(Module):
         bias = rearrange(values, 'i j h -> h i j')
         return bias * self.scale

+class CoPE(Module):
+    """
+    Appendix B of https://arxiv.org/abs/2405.18719
+    """
+    def __init__ (self, dim, max_pos):
+        super () . __init__ ()
+        self.max_pos = max_pos
+        self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
+
+    def forward(self, query, attn_logits):
+        # compute positions
+
+        gates = attn_logits.sigmoid()
+        pos = gates.flip(-1).cumsum(dim = -1).flip(-1)
+        pos = pos.clamp(max = self.max_pos - 1)
+
+        # interpolate from integer positions
+
+        pos_ceil = pos.ceil().long()
+        pos_floor = pos.floor().long()
+        logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)
+        logits_ceil = logits_int.gather(-1, pos_ceil)
+        logits_floor = logits_int.gather(-1, pos_floor)
+
+        w = pos - pos_floor
+        return logits_ceil * w + logits_floor * (1 - w)
+
 class DynamicPositionBias(Module):
     def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
         super().__init__()
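The new CoPE module is contextual position encoding per Appendix B of the paper: the attention logits are gated with a sigmoid, positions come from a reversed cumulative sum of those gates (clamped to max_pos - 1), and because those positions are fractional, the positional logits are linearly interpolated between the embeddings at the floor and ceiling integer positions. A quick shape-check sketch, assuming x-transformers 1.30.7 is installed; the sizes below are arbitrary:

import torch
from x_transformers.x_transformers import CoPE

q   = torch.randn(2, 8, 32, 64)   # queries: (batch, heads, seq, dim_head)
sim = torch.randn(2, 8, 32, 32)   # pre-softmax attention logits

cope = CoPE(dim = 64, max_pos = 16)
bias = cope(q, sim)               # interpolated positional logits, same shape as sim
assert bias.shape == sim.shape

Inside Attend this bias is simply added to sim before the softmax, so CoPE costs one extra einsum and two gathers per attention call.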
@@ -722,6 +749,8 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        use_cope = False,
+        cope_max_pos = 16,
         logit_softclamp_value = None,
         onnxable = False
     ):
@@ -753,13 +782,16 @@ class Attention(Module):
         self.to_k = nn.Linear(dim_kv, k_dim, bias = False)

         # shared key / values, for further memory savings during inference
+
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = nn.Linear(dim_kv, v_dim, bias = False) if not shared_kv else None

         # relations projection from tp-attention
+
         self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None

         # add GLU gating for aggregated values, from alphafold2
+
         self.to_v_gate = None
         if gate_values:
             self.to_v_gate = nn.Linear(dim, out_dim)
@@ -768,6 +800,7 @@ class Attention(Module):
             nn.init.constant_(self.to_v_gate.bias, 10)

         # add per head gating of the output values, from 'Attend to nothing' paper
+
         self.to_v_head_gate = None
         if gate_value_heads:
             self.to_v_head_gate = nn.Linear(dim, heads)
@@ -775,11 +808,13 @@ class Attention(Module):
             nn.init.constant_(self.to_v_head_gate.bias, 10)

         # cosine sim attention
+
         self.qk_norm = qk_norm
         self.qk_norm_groups = qk_norm_groups
         self.qk_norm_scale = qk_norm_scale

         # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
+
         self.qk_norm_dim_scale = qk_norm_dim_scale

         self.qk_norm_q_scale = self.qk_norm_k_scale = 1
@@ -790,6 +825,17 @@ class Attention(Module):
         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

+        # contextual positional encoding
+        # https://arxiv.org/html/2405.18719v2
+
+        cope = None
+
+        if use_cope:
+            assert causal, 'CoPE was designed for causal attention'
+            assert not flash, 'CoPE is not flash attention compatible'
+
+            cope = CoPE(dim_head, cope_max_pos)
+
         # attend class - includes core attention algorithm + talking heads

         self.attend = Attend(
@@ -803,31 +849,38 @@ class Attention(Module):
             add_zero_kv = add_zero_kv,
             flash = flash,
             logit_softclamp_value = logit_softclamp_value,
+            cope = cope,
             onnxable = onnxable
         )

         # head scaling
+
         self.head_scale = head_scale
         if head_scale:
             self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

         # explicit topk sparse attention
+
         self.sparse_topk = sparse_topk

         # add memory key / values
+
         self.num_mem_kv = num_mem_kv
         if num_mem_kv > 0:
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

         # attention on attention
+
         self.attn_on_attn = on_attn
         self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)

         # whether to rotate positions into values, for absolute positions in addition to relative
+
         self.rotary_embed_values = rotary_embed_values

         # init output projection 0
+
         if zero_init_output:
             init_zero_(self.to_out)

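On the Attention side, setting use_cope = True builds a CoPE(dim_head, cope_max_pos) and hands it to Attend as the new cope argument; the asserts above require causal attention and rule out flash attention, since CoPE needs the explicit logits. A hedged construction sketch (dim and heads values are arbitrary, and the remaining constructor defaults are assumed to be compatible):

import torch
from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    causal = True,     # required: CoPE was designed for causal attention
    use_cope = True,   # flash must stay off, per the assert
    cope_max_pos = 16
)

out = attn(torch.randn(1, 128, 512))   # assumed forward signature; may return (output, intermediates)

If the library's usual attn_-prefixed keyword forwarding applies, the same flags should be reachable from a Decoder as attn_use_cope = True and attn_cope_max_pos = 16, though this diff does not show that path.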
@@ -1410,6 +1463,7 @@ class AttentionLayers(Module):
             x = residual_fn(out, inner_residual)

             if layer_type in ('a', 'c') and return_hiddens:
+                inter.layer_type = layer_type
                 intermediates.append(inter)

             if layer_type == 'a' and self.residual_attn:
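Finally, each attention Intermediates record is now tagged with whether it came from a self-attention ('a') or cross-attention ('c') block via the new layer_type field. A hedged sketch of consuming that tag, assuming AttentionLayers-style modules return their intermediates when called with return_hiddens = True and expose the per-layer attention records on an attn_intermediates field (field name assumed, not shown in this diff):

import torch
from x_transformers import Decoder

layers = Decoder(dim = 512, depth = 2, heads = 8)

x = torch.randn(1, 64, 512)
out, intermediates = layers(x, return_hiddens = True)

# keep only the self-attention records, using the new tag
self_attn = [inter for inter in intermediates.attn_intermediates if inter.layer_type == 'a']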
{x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=ap2QkD-bRadFE9ZFQP84Lo1P2DpLOXPam24Jq9ybpPY,10903
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=r9F_LLp5bQyAlue3bBTRwoRx02noTCh4ICF8oWCw1wE,67657
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.7.dist-info/METADATA,sha256=_TniCg2s6tlimpfzpWeMCsCMOjsoYwUObBiXFdY-JhA,661
+x_transformers-1.30.7.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.7.dist-info/RECORD,,

{x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/LICENSE
File without changes

{x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/WHEEL
File without changes

{x_transformers-1.30.4.dist-info → x_transformers-1.30.7.dist-info}/top_level.txt
File without changes