x-transformers 1.43.4__py3-none-any.whl → 1.44.0__py3-none-any.whl
- x_transformers/x_transformers.py +40 -9
- {x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/METADATA +1 -1
- {x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/RECORD +6 -6
- {x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
@@ -38,6 +39,7 @@ class LayerIntermediates:
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
+    logit_entropies: Tensor | None = None
 
 LinearNoBias = partial(nn.Linear, bias = False)
 
@@ -136,6 +138,15 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# entropy
+
+def calc_entropy(
+    t: Tensor,
+    is_prob = False
+):
+    prob = t.softmax(dim = -1) if not is_prob else t
+    return -(prob * log(prob)).sum(dim = -1)
+
 # auxiliary loss helpers
 
 def calc_z_loss(
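The new calc_entropy helper returns the Shannon entropy of the softmax distribution along the last dimension (assuming the library's internal log is the usual clamped-logarithm helper). A minimal standalone sketch of the same computation, with illustrative names of our own choosing:

import torch

def entropy_from_logits(logits: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    # softmax over the last (vocab) dimension, then -sum(p * log p)
    prob = logits.softmax(dim = -1)
    return -(prob * prob.clamp(min = eps).log()).sum(dim = -1)

logits = torch.randn(2, 5, 10)        # (batch, seq, vocab)
ent = entropy_from_logits(logits)     # (batch, seq) - one entropy value per position
assert ent.shape == (2, 5)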
@@ -1126,6 +1137,7 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1325,6 +1337,10 @@
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
@@ -1397,6 +1413,16 @@
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1417,14 +1443,6 @@
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
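Taken together, these two hunks relocate the existing qk-normalization block so it runs before the kv-cache handling rather than after it, i.e. queries and keys are normalized before cached keys are concatenated back in. For readers unfamiliar with the technique, qk normalization l2-normalizes queries and keys so the attention logits become scaled cosine similarities; the following is a simplified self-contained sketch of that idea (our own, ignoring the grouped l2norm and the learned per-head q/k scales used in the code above):

import torch
import torch.nn.functional as F

def cosine_sim_attn_logits(q, k, scale = 10.0):
    # l2-normalize along the feature dimension, so q @ k^T becomes a cosine similarity in [-1, 1]
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))
    return torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale

q = torch.randn(1, 8, 16, 64)    # (batch, heads, seq, dim_head)
k = torch.randn(1, 8, 16, 64)
logits = cosine_sim_attn_logits(q, k)
assert logits.shape == (1, 8, 16, 16)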
@@ -1571,6 +1589,12 @@
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+            hybrid_out, _ = self.hybrid_module(x)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
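Together with the constructor changes above, this lets Attention average its output with a parallel "hybrid" branch, in the spirit of Hymba. The branch is called as self.hybrid_module(x) and its first return value must match the pre-projection attention output of shape (batch, seq, heads * dim_head). A hedged usage sketch, assuming the usual Attention(dim = ..., dim_head = ..., heads = ...) constructor; the choice of an nn.GRU is ours, picked because its forward happens to satisfy the two-tuple return convention:

import torch
from torch import nn
from x_transformers.x_transformers import Attention

dim, heads, dim_head = 512, 8, 64

attn = Attention(
    dim = dim,
    heads = heads,
    dim_head = dim_head,
    # GRU output has shape (batch, seq, heads * dim_head), and its forward
    # returns (output, hidden), matching the `hybrid_out, _ = ...` unpacking above
    hybrid_module = nn.GRU(dim, dim_head * heads, batch_first = True)
)

x = torch.randn(1, 1024, dim)
out, _ = attn(x)
assert out.shape == (1, 1024, dim)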
@@ -1993,7 +2017,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
@@ -2592,6 +2616,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_logit_entropies = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -2809,6 +2834,12 @@
         else:
             out = logits
 
+        # logit entropies
+
+        if return_logit_entropies:
+            intermediates.logit_entropies = calc_entropy(logits)
+            return_intermediates = True
+
         # aux loss
 
         if return_attn_z_loss:
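Requesting entropies from TransformerWrapper forces return_intermediates = True, so the per-position entropies come back on the intermediates object alongside the logits. A quick sketch of exercising the new flag; the model config and shapes below are purely illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

ids = torch.randint(0, 256, (1, 128))

logits, intermediates = model(ids, return_logit_entropies = True)

print(logits.shape)                          # torch.Size([1, 128, 256])
print(intermediates.logit_entropies.shape)   # torch.Size([1, 128]) - entropy per position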
{x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/RECORD
CHANGED
@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=BI3RU3XFvwSNDZgoQBrFBSJ4SavJr38rOCCVgHZBTx0,101241
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.44.0.dist-info/METADATA,sha256=MNVwW_pDeKEIHRVEA1XOUNfzFmL706X7Npoh7xc3wIk,738
+x_transformers-1.44.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.44.0.dist-info/RECORD,,
{x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/LICENSE
File without changes
{x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/WHEEL
File without changes
{x_transformers-1.43.4.dist-info → x_transformers-1.44.0.dist-info}/top_level.txt
File without changes