x-transformers 1.43.5__py3-none-any.whl → 1.44.1__py3-none-any.whl
- x_transformers/x_transformers.py +42 -10
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/METADATA +1 -1
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/RECORD +6 -6
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2,14 +2,16 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
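Two of the new imports matter later in this diff: `deepcopy` is used to copy the hybrid module at construction time, and `tree_flatten` (from `torch.utils._pytree`) is used to pull the first tensor out of whatever the hybrid module returns. A minimal sketch of what `tree_flatten` does, with illustrative shapes:

    import torch
    from torch.utils._pytree import tree_flatten

    # a recurrent module such as nn.GRU returns a tuple (output, hidden_state);
    # tree_flatten collapses any such nested structure into a flat list of leaves plus a spec
    output, hidden = torch.randn(1, 4, 8), torch.randn(1, 1, 8)

    leaves, _spec = tree_flatten((output, hidden))
    first, *rest = leaves       # first leaf is the main output tensor
    print(first.shape)          # torch.Size([1, 4, 8])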
@@ -1136,6 +1138,8 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1335,6 +1339,12 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
@@ -1407,6 +1417,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1427,14 +1447,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
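The two hunks above only relocate the qk-normalization block so that it runs before the kv-cache handling; the computation itself is unchanged. For reference, grouped l2-normalization of q and k turns the attention logits into scaled cosine similarities. A standalone sketch of that idea (an illustrative re-implementation, not the library's exact `l2norm` helper; the scale value is arbitrary):

    import torch
    import torch.nn.functional as F

    def l2norm_grouped(t, groups = 1):
        # split the feature dimension into `groups`, unit-normalize each group, recombine
        *lead, d = t.shape
        t = t.reshape(*lead, groups, d // groups)
        t = F.normalize(t, p = 2, dim = -1)
        return t.reshape(*lead, d)

    # toy q / k shaped (batch, heads, seq, dim_head)
    q = torch.randn(1, 8, 4, 64)
    k = torch.randn(1, 8, 4, 64)

    q, k = map(l2norm_grouped, (q, k))

    # with unit-norm q and k, q @ k^T gives cosine similarities, so a fixed or learned
    # scale (qk_norm_scale in the diff) replaces the usual 1 / sqrt(dim_head)
    scale = 10.
    logits = (q @ k.transpose(-1, -2)) * scale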
@@ -1581,6 +1593,26 @@ class Attention(Module):
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
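Together with the constructor changes above, this hunk adds a Hymba-style parallel branch: a deep copy of `hybrid_module` is run on the same input `x`, the first tensor of its (possibly nested) output is recovered via `tree_flatten`, and the result is averaged 50/50 with the attention output before gating and the output projection. A minimal usage sketch, assuming a GRU whose output width matches the attention inner dimension (heads * dim_head); the GRU choice and all dimensions are illustrative, only the `hybrid_module` and `hybrid_mask_kwarg` keyword arguments come from this diff:

    import torch
    from torch import nn
    from x_transformers.x_transformers import Attention

    dim = 512

    # any module mapping (b, n, dim) -> (b, n, heads * dim_head) can serve as the branch;
    # nn.GRU returns (output, hidden), and tree_flatten picks the first tensor as hybrid_out
    attn = Attention(
        dim = dim,
        heads = 8,
        dim_head = 64,                                          # 8 * 64 = 512 matches the GRU output width
        causal = True,
        hybrid_module = nn.GRU(dim, dim, batch_first = True)    # stored internally as a deepcopy
    )

    x = torch.randn(1, 256, dim)
    out, intermediates = attn(x)    # attention output averaged 0.5 / 0.5 with the GRU branch
    print(out.shape)                # torch.Size([1, 256, 512])

For a non-causal Attention, setting `hybrid_mask_kwarg = 'mask'` would forward the boolean padding mask into the hybrid module under that keyword, provided the chosen module accepts it.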
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/RECORD
RENAMED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
+x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.1.dist-info/RECORD,,
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/LICENSE
File without changes
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/top_level.txt
File without changes