x-transformers 1.43.5-py3-none-any.whl → 1.44.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +42 -10
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/METADATA +1 -1
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/RECORD +6 -6
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/WHEEL +1 -1
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2,14 +2,16 @@ from __future__ import annotations
 from typing import Callable

 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version

 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
@@ -1136,6 +1138,8 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
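The two new `hybrid_module` and `hybrid_mask_kwarg` arguments let an `Attention` block run a parallel sequence-mixing module alongside softmax attention, in the vein of Hymba. A minimal sketch of wiring one up directly, assuming `heads * dim_head` matches the hybrid module's output width and that `Attention.forward` returns the output tensor first; the GRU choice and all sizes are illustrative, not prescribed by the package:

import torch
from torch.nn import GRU
from x_transformers.x_transformers import Attention

dim = 512  # illustrative; heads * dim_head must match the hybrid module's output width

attn = Attention(
    dim = dim,
    heads = 8,
    dim_head = 64,
    causal = True,
    hybrid_module = GRU(dim, dim, batch_first = True)  # runs in parallel with attention
)

x = torch.randn(2, 128, dim)
out, *_ = attn(x)   # attention output is averaged 50/50 with the GRU output internally
print(out.shape)    # torch.Size([2, 128, 512])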
@@ -1335,6 +1339,12 @@ class Attention(Module):

         self.attn_on_attn = on_attn

+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
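For non-causal (bidirectional) use, `hybrid_mask_kwarg` names the keyword under which the attention mask is forwarded to the hybrid module, so it can handle variable-length sequences itself. A sketch with a made-up hybrid branch; the `MaskedMixer` class, its boolean-mask convention, and the sizes are illustrative assumptions, not part of the package:

import torch
from torch import nn
from x_transformers.x_transformers import Attention

class MaskedMixer(nn.Module):
    # illustrative hybrid branch: zero out padded positions, then mix locally with a conv
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size = 3, padding = 1)

    def forward(self, x, mask = None):
        if mask is not None:
            x = x * mask[..., None]                # drop padded positions before mixing
        return self.conv(x.transpose(1, 2)).transpose(1, 2)

dim = 512
attn = Attention(
    dim = dim,
    heads = 8,
    dim_head = 64,
    causal = False,                                # the mask is only forwarded when non-causal
    hybrid_module = MaskedMixer(dim),
    hybrid_mask_kwarg = 'mask'                     # kwarg name the hybrid module expects
)

x = torch.randn(2, 128, dim)
mask = torch.ones(2, 128, dtype = torch.bool)      # True marks real (non-padded) tokens
out, *_ = attn(x, mask = mask)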
@@ -1407,6 +1417,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching

         if exists(cache):
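This moves qk normalization ahead of the kv-cache bookkeeping, so keys are normalized and scaled before they are cached. A standalone sketch of what the grouped l2 normalization does here, re-implemented for illustration on the assumption that the package's `l2norm` helper unit-normalizes the last dimension group-wise:

import torch
import torch.nn.functional as F

def grouped_l2norm(t, groups = 1):
    # split the feature dimension into `groups` chunks and give each unit l2 norm
    *_, d = t.shape
    t = t.reshape(*t.shape[:-1], groups, d // groups)
    t = F.normalize(t, p = 2, dim = -1)
    return t.reshape(*t.shape[:-2], d)

q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 128, 64)
q, k = map(grouped_l2norm, (q, k))

# with unit-norm q and k the attention logits are cosine similarities in [-1, 1],
# which is why a separate qk_norm_scale replaces the usual 1/sqrt(dim_head) scaling
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
assert sim.abs().max() <= 1 + 1e-4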
@@ -1427,14 +1447,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
@@ -1581,6 +1593,26 @@ class Attention(Module):

         out = rearrange(out, 'b h n d -> b n (h d)')

+        # hybrid module
+
+        if exists(self.hybrid_module):
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values

         if exists(self.to_v_gate):
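The hybrid branch is fed the same input `x` as attention, and `tree_flatten` makes the fusion agnostic to the hybrid module's return type: a bare tensor, a tuple like `nn.GRU`'s `(output, h_n)`, or a nested structure all yield their leading tensor, which is then averaged equally with the attention output. A small illustration of that flattening step, using a GRU and stand-in tensors:

import torch
from torch.nn import GRU
from torch.utils._pytree import tree_flatten

gru = GRU(512, 512, batch_first = True)
x = torch.randn(2, 128, 512)

hybrid_outputs = gru(x)                          # GRU returns (output, h_n)
(hybrid_out, *rest), _ = tree_flatten(hybrid_outputs)

print(hybrid_out.shape)                          # torch.Size([2, 128, 512]), the per-token output
print([t.shape for t in rest])                   # [torch.Size([1, 2, 512])], the final hidden state

attn_out = torch.randn(2, 128, 512)              # stand-in for the attention branch output
mixed = 0.5 * (attn_out + hybrid_out)            # equal-weight fusion, as in the hunk above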
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):

         # determine whether can cache kv

-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

     def forward(
         self,
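The 1.44.1 change restores the closing bracket on the comprehension that aggregates `can_cache_kv` over every `Attention` submodule. A quick check of the resulting flag, using the package's standard `Decoder` stack (the sizes are illustrative):

from x_transformers import Decoder

decoder = Decoder(dim = 512, depth = 2, heads = 8)

# True only when every Attention submodule in the stack supports kv caching
print(decoder.can_cache_kv)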
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/RECORD
RENAMED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
+x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.1.dist-info/RECORD,,
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/LICENSE
File without changes
{x_transformers-1.43.5.dist-info → x_transformers-1.44.1.dist-info}/top_level.txt
File without changes