x-transformers 2.7.4-py3-none-any.whl → 2.7.6-py3-none-any.whl
This diff reflects the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only.
- x_transformers/x_transformers.py +24 -0
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/METADATA +1 -1
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/RECORD +5 -5
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/WHEEL +0 -0
- {x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2462,6 +2462,23 @@ class AttentionLayers(Module):
 
         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        # pairs up the attention intermediates with each attention module and does qk clip proposed by kimi team
+
+        layer_and_layer_types = (self.layers, self.layer_types)
+
+        attn_layers = [layer for (_, layer, _), layer_type in zip(self.layers, self.layer_types) if layer_type in ('a', 'c')]
+        attn_intermeds = intermediates.attn_intermediates
+
+        assert len(attn_layers) == len(attn_intermeds)
+
+        for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
+            attn_layer.qk_clip_(attn_inter, tau = tau)
+
     def forward(
         self,
         x,
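The per-module `qk_clip_` that this loop delegates to predates this diff and is not shown here. For context, the QK-Clip idea from the Kimi team rescales a head's query and key projection weights whenever that head's maximum pre-softmax attention logit exceeds the threshold `tau`. The following is a minimal, hypothetical sketch of that operation, not the library's `Attention.qk_clip_`; the function name, the `heads` argument, and the per-head `max_attn_logits` tensor (taken from the attention intermediates) are all assumptions.

```python
import torch
from torch import nn

def qk_clip_sketch_(to_q: nn.Linear, to_k: nn.Linear, max_attn_logits: torch.Tensor, tau = 100., heads = 8):
    # per head: if the max attention logit exceeded tau, shrink both the query
    # and the key projection weights by sqrt(tau / max_logit), so their product
    # pulls that head's logits back down to roughly tau; heads already below
    # tau get a factor of 1. and are left untouched
    gamma = (tau / max_attn_logits.clamp(min = tau)).sqrt()  # shape (heads,)

    dim_head = to_q.weight.shape[0] // heads

    with torch.no_grad():
        to_q.weight.view(heads, dim_head, -1).mul_(gamma.view(heads, 1, 1))
        to_k.weight.view(heads, dim_head, -1).mul_(gamma.view(heads, 1, 1))
```

Splitting the correction evenly between the two projections (a square root on each) brings the clipped logits back near `tau` without changing which keys a query attends to most.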
@@ -3192,6 +3209,13 @@ class TransformerWrapper(Module):
         if not isinstance(self.pos_emb, always):
             nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
 
+    def attn_qk_clip_(
+        self,
+        intermediates: LayerIntermediates,
+        tau = 100.
+    ):
+        self.attn_layers.attn_qk_clip_(intermediates, tau = tau)
+
     def forward(
         self,
         x,
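The wrapper method simply forwards to `AttentionLayers.attn_qk_clip_`, so the intended call site is the training loop, after the weight update. A rough usage sketch, assuming the forward pass is asked for its `LayerIntermediates` via `return_intermediates = True` and that those intermediates carry the per-head attention statistics the clipping needs (the exact flags may differ by configuration):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

seq = torch.randint(0, 256, (1, 1024))

# ask the forward pass to also return the layer intermediates
logits, intermediates = model(seq, return_intermediates = True)

# ... loss, backward, optimizer step ...

# then clip query / key projections of any head whose max attention logit exceeded tau
model.attn_qk_clip_(intermediates, tau = 100.)
```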
{x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/RECORD
CHANGED
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.7.…
-x_transformers-2.7.…
-x_transformers-2.7.…
-x_transformers-2.7.…
+x_transformers-2.7.6.dist-info/METADATA,sha256=n-AKJXX2Ko3XlehMOv5MojPrFaHdRi4lRkvcGAFOXR4,93739
+x_transformers-2.7.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.7.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.7.6.dist-info/RECORD,,
{x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/WHEEL
File without changes
{x_transformers-2.7.4.dist-info → x_transformers-2.7.6.dist-info}/licenses/LICENSE
File without changes