x-transformers 1.38.2__py3-none-any.whl → 1.38.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +3 -2
- x_transformers/x_transformers.py +9 -1
- {x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/METADATA +1 -1
- {x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/RECORD +7 -7
- {x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -67,7 +67,7 @@ print_once = once(print)
 
 # selective attention
 # https://arxiv.org/abs/2410.02703 - section 3.3
-# it is a
+# it is a technique to allow each token to prevent itself from being attended to by future tokens
 # if sim_head_gate not supplied, will use the first head of the attention logits (sim in this framework)
 
 def selective_attn(
@@ -81,9 +81,10 @@ def selective_attn(
     gate = F.relu(sim_head_gate) # only positive
 
     if no_mask_sos:
+        gate = gate.clone()
         gate[..., -i] = 0.
 
-    eye = torch.eye(
+    eye = torch.eye(i, device = device)
 
     if j > i:
         eye = F.pad(eye, (j - i, 0), value = 1.)
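The hunks above clarify the selective-attention comment and patch `selective_attn`: the gate is cloned before the in-place write, and the `torch.eye` construction now receives the size `i` and the `device`. Below is a minimal, hedged sketch of the gating pattern these lines belong to; the function name `selective_attn_sketch`, the tensor shapes, and everything after the `eye` construction are assumptions for illustration, not the library's exact code.

```python
import torch
import torch.nn.functional as F

def selective_attn_sketch(sim, sim_head_gate = None, no_mask_sos = True):
    # sim: attention logits with assumed shape (batch, heads, i, j)
    i, j = sim.shape[-2:]
    device = sim.device

    if sim_head_gate is None:
        # per the comment in the diff, default to the first head of the logits
        sim_head_gate = sim[:, 0]

    gate = F.relu(sim_head_gate)  # only positive

    if no_mask_sos:
        # the 1.38.3 fix: clone before the in-place write so the source tensor is not mutated
        gate = gate.clone()
        gate[..., -i] = 0.

    # identity so a token does not gate itself; left-pad with ones when a cached
    # prefix makes j > i, mirroring the F.pad call in the hunk
    eye = torch.eye(i, device = device)
    if j > i:
        eye = F.pad(eye, (j - i, 0), value = 1.)

    # assumed remainder: accumulate each token's gate over later queries
    # and subtract it from the logits
    gate = (gate * (1. - eye)).cumsum(dim = -2)
    gate = F.pad(gate, (0, 0, 1, -1), value = 0.)
    return sim - gate.unsqueeze(1)
```

Called on pre-softmax logits of shape `(batch, heads, i, j)`, e.g. `selective_attn_sketch(torch.randn(1, 2, 4, 4))`.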
x_transformers/x_transformers.py
CHANGED
@@ -1084,6 +1084,10 @@ class Attention(Module):
 
         self.rotary_embed_values = rotary_embed_values
 
+        # whether parent can kv cache
+
+        self.can_cache_kv = not selective
+
         # init output projection 0
 
         if zero_init_output:
@@ -1634,6 +1638,10 @@ class AttentionLayers(Module):
                 residual
             ]))
 
+        # determine whether can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+
     def forward(
         self,
         x,
@@ -2135,7 +2143,7 @@ class TransformerWrapper(Module):
 
         # whether can do cached kv decoding
 
-        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling
+        self.can_cache_kv = self.num_memory_tokens == 0 and not recycling and self.attn_layers.can_cache_kv
         self.can_cache_kv_outside_max_seq_len = no_abs_pos_emb
 
     def init_(self):
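These three hunks thread a `can_cache_kv` flag upward: each `Attention` reports whether it can be kv-cached (not when selective attention is enabled), `AttentionLayers` aggregates the flag over its attention submodules, and `TransformerWrapper` folds it into its existing check. A minimal sketch of that propagation, with stubbed-out class bodies and assumed constructor arguments (`depth`, the `selective` pass-through), follows; only the `can_cache_kv` logic mirrors the diff.

```python
from torch import nn

class Attention(nn.Module):
    def __init__(self, selective = False):
        super().__init__()
        # selective attention cannot reuse a kv cache, so the parent must not cache
        self.can_cache_kv = not selective

class AttentionLayers(nn.Module):
    def __init__(self, depth = 2, selective = False):
        super().__init__()
        self.layers = nn.ModuleList([Attention(selective = selective) for _ in range(depth)])
        # cacheable only if every attention submodule is cacheable
        self.can_cache_kv = all(m.can_cache_kv for m in self.modules() if isinstance(m, Attention))

class TransformerWrapper(nn.Module):
    def __init__(self, attn_layers, num_memory_tokens = 0, recycling = False):
        super().__init__()
        self.attn_layers = attn_layers
        self.num_memory_tokens = num_memory_tokens
        # 1.38.3 adds the attn_layers check, so selective attention disables cached kv decoding
        self.can_cache_kv = num_memory_tokens == 0 and not recycling and attn_layers.can_cache_kv

# a wrapper around selective-attention layers now reports it cannot cache kv
wrapper = TransformerWrapper(AttentionLayers(selective = True))
assert not wrapper.can_cache_kv
```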
{x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=Z9cnbY3f9zCl1yUOMWVQT5_ee4keC0lo_NX4cd0rbKk,15393
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=P3o1R2DY2ic71FlJJ4ie4w_z-g3jIIrkBXcbllRoXHA,84240
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.38.
-x_transformers-1.38.
-x_transformers-1.38.
-x_transformers-1.38.
-x_transformers-1.38.
+x_transformers-1.38.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.38.3.dist-info/METADATA,sha256=XbmJ91NWqmOjO0xvyXh9uh82TIwVbU54L3eFmaFVvYs,661
+x_transformers-1.38.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+x_transformers-1.38.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.38.3.dist-info/RECORD,,
{x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/LICENSE
File without changes

{x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/WHEEL
File without changes

{x_transformers-1.38.2.dist-info → x_transformers-1.38.3.dist-info}/top_level.txt
File without changes