x-transformers 2.7.4__py3-none-any.whl → 2.7.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +21 -0
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.5.dist-info}/METADATA +1 -1
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.5.dist-info}/RECORD +5 -5
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2462,6 +2462,20 @@ class AttentionLayers(Module):
|
|
2462
2462
|
|
2463
2463
|
self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
|
2464
2464
|
|
2465
|
+
def attn_qk_clip_(
    self,
    intermediates: LayerIntermediates,
    tau = 100.
):
    """Apply qk-clip (proposed by the Kimi team) across all attention layers.

    Walks the stored layers in lockstep with the attention intermediates
    captured during a forward pass, and asks each attention / cross-attention
    module to clip its query-key logits in place.

    intermediates - LayerIntermediates holding one attn_intermediates entry
                    per layer, in layer order
    tau           - clipping threshold forwarded to each module's qk_clip_
    """
    paired = zip(self.layers, self.layer_types, intermediates.attn_intermediates)

    for (_, block, _), block_type, block_inter in paired:
        # only self-attention ('a') and cross-attention ('c') layers carry
        # qk logits to clip; feedforwards etc. are passed over
        if block_type in ('a', 'c'):
            block.qk_clip_(block_inter, tau = tau)
|
2478
|
+
|
2465
2479
|
def forward(
|
2466
2480
|
self,
|
2467
2481
|
x,
|
@@ -3192,6 +3206,13 @@ class TransformerWrapper(Module):
|
|
3192
3206
|
if not isinstance(self.pos_emb, always):
|
3193
3207
|
nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
|
3194
3208
|
|
3209
|
+
def attn_qk_clip_(
    self,
    intermediates: LayerIntermediates,
    tau = 100.
):
    """Convenience pass-through: delegate qk-clipping to the wrapped attention layers.

    intermediates - LayerIntermediates from a forward pass of this wrapper
    tau           - clipping threshold, forwarded unchanged
    """
    # the real clipping logic lives on the AttentionLayers instance
    attn_layers = self.attn_layers
    attn_layers.attn_qk_clip_(intermediates, tau = tau)
|
3215
|
+
|
3195
3216
|
def forward(
|
3196
3217
|
self,
|
3197
3218
|
x,
|
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
|
|
9
9
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
10
10
|
x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
|
11
11
|
x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
|
12
|
-
x_transformers/x_transformers.py,sha256=
|
12
|
+
x_transformers/x_transformers.py,sha256=xaGBkYCy6CqL0q9icWmL_WzCeU6ZztEYEkMtN71L2z4,124576
|
13
13
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
14
14
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
15
|
-
x_transformers-2.7.
|
16
|
-
x_transformers-2.7.
|
17
|
-
x_transformers-2.7.
|
18
|
-
x_transformers-2.7.
|
15
|
+
x_transformers-2.7.5.dist-info/METADATA,sha256=m6f4PIgJFKKWlsGAydi_Bg5-7-0IRlor0pRY_zBh5s8,93739
|
16
|
+
x_transformers-2.7.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
17
|
+
x_transformers-2.7.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
18
|
+
x_transformers-2.7.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|