x-transformers 1.38.2__py3-none-any.whl → 1.39.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +14 -11
- x_transformers/x_transformers.py +13 -3
- {x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/METADATA +1 -1
- {x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/RECORD +7 -7
- {x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -67,7 +67,7 @@ print_once = once(print)

 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
-# it is a …
+# it is a technique to allow each token to prevent itself from being attended to by future tokens
 # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)

 def selective_attn(
@@ -81,9 +81,10 @@ def selective_attn(
     gate = F.relu(sim_head_gate) # only positive

     if no_mask_sos:
+        gate = gate.clone()
         gate[..., -i] = 0.

-    eye = torch.eye( …
+    eye = torch.eye(i, device = device)

     if j > i:
         eye = F.pad(eye, (j - i, 0), value = 1.)
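For orientation, a tiny standalone sketch (not part of the package) of the eye/pad construction this hunk touches: i and j play the roles of query and key lengths inside selective_attn, and the concrete values below are made up for illustration.

import torch
import torch.nn.functional as F

i, j = 3, 5                               # query / key lengths; j > i when cached keys sit in front
eye = torch.eye(i)                        # the fixed line now passes the size (and device) explicitly
eye = F.pad(eye, (j - i, 0), value = 1.)  # left-pad so the prefix key positions stay at 1.
print(eye)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 0., 1., 0.],
#         [1., 1., 0., 0., 1.]])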
@@ -127,7 +128,8 @@ class Attend(Module):
         dropout = 0.,
         causal = False,
         heads = None,
-        …
+        pre_talking_heads = True,
+        post_talking_heads = True,
         sparse_topk = None,
         scale = None,
         qk_norm = False,
@@ -178,14 +180,15 @@ class Attend(Module):

         # talking heads

-        assert not (flash and …
+        assert not (flash and (pre_talking_heads or post_talking_heads)), 'talking heads not compatible with flash attention'

-        self. …
-        if …
-            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
-            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
+        self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if pre_talking_heads else None
+        self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if post_talking_heads else None

+        if exists(self.pre_softmax_talking_heads):
             nn.init.dirac_(self.pre_softmax_talking_heads.weight)
+
+        if exists(self.post_softmax_talking_heads):
             nn.init.dirac_(self.post_softmax_talking_heads.weight)

         # selective attention
@@ -433,8 +436,8 @@ class Attend(Module):

         qk_similarities = sim.clone()

-        if self. …
-            sim = self.pre_softmax_talking_heads(sim)
+        if exists(self.pre_softmax_talking_heads):
+            sim = sim + self.pre_softmax_talking_heads(sim)

         if exists(attn_bias):
             sim = sim + attn_bias
@@ -481,7 +484,7 @@ class Attend(Module):

         attn = self.attn_dropout(attn)

-        if self. …
+        if exists(self.post_softmax_talking_heads):
             attn = self.post_softmax_talking_heads(attn)

         out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
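A minimal usage sketch of the split introduced above, relying only on what these hunks show: Attend now takes independent pre_talking_heads / post_talking_heads flags (defaulting to True), adds the pre-softmax head mix residually to the logits, and asserts that flash attention is only used with both flags off. Shapes, the remaining defaults and the returned intermediates are assumptions for illustration, not taken from the diff.

import torch
from x_transformers.attend import Attend

# both mixes enabled (the new defaults); heads must be given so the 1x1 convs can be built
attend = Attend(
    heads = 8,
    causal = True,
    pre_talking_heads = True,    # mixes attention logits across heads before softmax (added residually)
    post_talking_heads = True,   # mixes attention weights across heads after softmax
)

q = k = v = torch.randn(1, 8, 16, 64)     # (batch, heads, seq, dim_head) - assumed layout
out, intermediates = attend(q, k, v)      # assumed return signature

# per the updated assert, flash attention requires both flags to be turned off
flash_attend = Attend(
    heads = 8,
    causal = True,
    flash = True,
    pre_talking_heads = False,
    post_talking_heads = False,
)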
x_transformers/x_transformers.py
CHANGED
@@ -907,7 +907,8 @@ class Attention(Module):
         heads = 8,
         causal = False,
         flash = False,
-        …
+        pre_talking_heads = False,
+        post_talking_heads = False,
         head_scale = False,
         sparse_topk = None,
         num_mem_kv = 0,
@@ -1036,7 +1037,8 @@ class Attention(Module):
         self.attend = Attend(
             heads = heads,
             causal = causal,
-            …
+            pre_talking_heads = pre_talking_heads,
+            post_talking_heads = post_talking_heads,
             dropout = dropout,
             sparse_topk = sparse_topk,
             qk_norm = qk_norm,
@@ -1084,6 +1086,10 @@ class Attention(Module):

         self.rotary_embed_values = rotary_embed_values

+        # whether parent can kv cache
+
+        self.can_cache_kv = not selective
+
         # init output projection 0

         if zero_init_output:
@@ -1634,6 +1640,10 @@ class AttentionLayers(Module):
                 residual
             ]))

+        # determine whether can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+
     def forward(
         self,
         x,
@@ -2135,7 +2145,7 @@ class TransformerWrapper(Module):

         # whether can do cached kv decoding

-        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb

     def init_(self):
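A hedged end-to-end sketch of how the new flags and the can_cache_kv propagation could surface through the wrapper API; it assumes the library's usual attn_-prefixed keyword routing into Attention, and that attn_selective, attn_pre_talking_heads and attn_post_talking_heads map onto the parameters visible in the hunks above.

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        attn_pre_talking_heads = True,    # assumed routing into Attention(pre_talking_heads = ...)
        attn_post_talking_heads = False,  # assumed routing into Attention(post_talking_heads = ...)
        attn_selective = True,            # assumed routing into Attention(selective = ...)
    )
)

# Attention sets can_cache_kv = not selective, AttentionLayers aggregates it over all
# attention modules, and TransformerWrapper now folds that flag into its own check
print(model.attn_layers.can_cache_kv)   # expected: False, since selective attention is enabled
print(model.can_cache_kv)               # expected: False for the same reason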
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256= …
+x_transformers/attend.py,sha256=sTeX7DmUt6I5FhHtgcTDOIvmD1CvJ1PmVjZ_-lYO-QA,15596
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256= …
+x_transformers/x_transformers.py,sha256=nfq-EOLx5HWf8tXlmVuDbkhqNfFnfqRMCEkALK3SFkA,84341
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1. …
-x_transformers-1. …
-x_transformers-1. …
-x_transformers-1. …
-x_transformers-1. …
+x_transformers-1.39.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.39.0.dist-info/METADATA,sha256=-ppCMMH6ZTsmwaMJB9q4b4Yvd-nU8v95l5SdzTY17OU,661
+x_transformers-1.39.0.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.39.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.39.0.dist-info/RECORD,,
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/LICENSE
File without changes
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/WHEEL
File without changes
{x_transformers-1.38.2.dist-info → x_transformers-1.39.0.dist-info}/top_level.txt
File without changes