x-transformers 1.43.5__py3-none-any.whl → 1.44.0__py3-none-any.whl
- x_transformers/x_transformers.py +23 -9
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/METADATA +1 -1
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/RECORD +6 -6
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
@@ -1136,6 +1137,7 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1335,6 +1337,10 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
@@ -1407,6 +1413,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1427,14 +1443,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
            freqs, xpos_scale = rotary_pos_emb
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
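The two hunks above do not change the qk-normalization math; they relocate the existing block so that it runs before kv caching and rotary embeddings. For reference, the block implements l2-normalized (cosine-similarity) attention. Below is a minimal standalone sketch of that computation, not the library's exact helper; the tensor shapes, the groups = 1 default, and the scale value are illustrative assumptions.

import torch
import torch.nn.functional as F

def l2norm(t, groups = 1):
    # normalize the last dimension in `groups` chunks
    t = t.reshape(*t.shape[:-1], groups, -1)
    t = F.normalize(t, p = 2, dim = -1)
    return t.reshape(*t.shape[:-2], -1)

# hypothetical shapes: (batch, heads, seq, dim_head)
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)

q, k = map(l2norm, (q, k))

# with unit-norm q and k, q @ k^T is a cosine similarity in [-1, 1],
# so a learned or fixed scale takes the place of the usual 1/sqrt(d)
scale = 10.
sim = scale * (q @ k.transpose(-1, -2))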
@@ -1581,6 +1589,12 @@ class Attention(Module):
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+            hybrid_out, _ = self.hybrid_module(x)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
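Together with the constructor hunks earlier, the hunk above completes the optional hybrid branch: Attention keeps a deepcopy of hybrid_module and averages its output with the attention output. A minimal usage sketch, assuming the 1.44.0 Attention API as shown in this diff; the nn.GRU branch and all sizes are illustrative, and any module whose forward returns an (output, aux) tuple with output of shape (batch, seq, heads * dim_head) fits the hybrid_out, _ = self.hybrid_module(x) call.

import torch
from torch import nn
from x_transformers.x_transformers import Attention

# illustrative hybrid branch: nn.GRU already returns (output, hidden),
# matching the tuple unpacking done inside Attention.forward
recurrent_branch = nn.GRU(512, 512, batch_first = True)

# heads * dim_head (8 * 64) matches the branch output width (512),
# so out = 0.5 * (out + hybrid_out) adds tensors of the same shape
attn = Attention(dim = 512, heads = 8, hybrid_module = recurrent_branch)

x = torch.randn(2, 1024, 512)
out, intermediates = attn(x)  # out: (2, 1024, 512), averaged with the GRU branch output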
@@ -2003,7 +2017,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
{x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=BI3RU3XFvwSNDZgoQBrFBSJ4SavJr38rOCCVgHZBTx0,101241
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.0.dist-info/METADATA,sha256=MNVwW_pDeKEIHRVEA1XOUNfzFmL706X7Npoh7xc3wIk,738
+x_transformers-1.44.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.0.dist-info/RECORD,,
{x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/LICENSE
File without changes
{x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/WHEEL
File without changes
{x_transformers-1.43.5.dist-info → x_transformers-1.44.0.dist-info}/top_level.txt
File without changes