x-transformers 1.44.0__tar.gz → 1.44.2__tar.gz
- {x_transformers-1.44.0/x_transformers.egg-info → x_transformers-1.44.2}/PKG-INFO +1 -1
- {x_transformers-1.44.0 → x_transformers-1.44.2}/README.md +3 -1
- {x_transformers-1.44.0 → x_transformers-1.44.2}/setup.py +1 -1
- {x_transformers-1.44.0 → x_transformers-1.44.2}/tests/test_x_transformers.py +17 -2
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/x_transformers.py +20 -2
- {x_transformers-1.44.0 → x_transformers-1.44.2/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.44.0 → x_transformers-1.44.2}/LICENSE +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/setup.cfg +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/__init__.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/attend.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/continuous.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/dpo.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/xval.py +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/top_level.txt +0 -0
README.md

@@ -317,7 +317,9 @@ model = TransformerWrapper(
 
 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).
 
-Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.
+
+Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, termed persistent memory
 
 ### Transformers Without Tears
 
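For context, the memory (register) tokens this README passage discusses are turned on through the `num_memory_tokens` keyword on `TransformerWrapper`, shown in the code example just above this hunk. A minimal sketch with illustrative hyperparameters:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,          # 20 learned memory / register tokens prepended to every sequence
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                    # (1, 1024, 20000) - memory tokens are handled internally
```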
tests/test_x_transformers.py

@@ -617,7 +617,7 @@ def test_hyper_connections(tanh):
 
 def test_hybrid():
     from torch.nn import GRU
 
-
+    dec = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
         attn_layers = Decoder(
@@ -631,4 +631,19 @@ def test_hybrid():
 
     x = torch.randint(0, 20000, (2, 1024))
 
-    embed =
+    embed = dec(x)
+
+    enc = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
+        )
+    )
+
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+    embed = enc(x, mask = mask)
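A note on the dimensions in the new encoder test: the hybrid branch is averaged with the attention output (`out = 0.5 * (out + hybrid_out)` in the forward hunk further down), so its feature width presumably has to match the attention inner dimension, here `heads * attn_dim_head = 8 * 64 = 512`; a bidirectional GRU with hidden size `64 * 4 = 256` produces exactly `2 * 256 = 512` features. A quick standalone shape check with plain PyTorch:

```python
import torch
from torch.nn import GRU

# bidirectional GRU: output feature size is 2 * hidden_size
gru = GRU(128, 64 * 4, batch_first = True, bidirectional = True)

x = torch.randn(2, 1024, 128)
out, _ = gru(x)

assert out.shape == (2, 1024, 2 * 64 * 4)   # 512, i.e. heads * attn_dim_head
```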
x_transformers/x_transformers.py

@@ -7,10 +7,11 @@ from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
@@ -1138,6 +1139,7 @@ class Attention(Module):
         selective = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,6 +1343,8 @@ class Attention(Module):
 
         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
 
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
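To illustrate what the new `hybrid_mask_kwarg` hook is for: a non-causal hybrid module whose forward accepts a mask keyword can have the attention `mask` forwarded to it under that name. The module below is a hypothetical sketch, not part of x-transformers; the keyword name `'mask'` and the masking strategy are assumptions for illustration.

```python
import torch
from torch import nn

class MaskedConvBranch(nn.Module):
    # hypothetical hybrid branch: with `hybrid_mask_kwarg = 'mask'`, the Attention
    # forward below would call it as `self.hybrid_module(x, mask = mask)` on
    # non-causal (encoder) layers, letting it handle variable lengths itself

    def __init__(self, dim, dim_out):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim_out, 3, padding = 1)

    def forward(self, x, mask = None):
        # x: (batch, seq, dim) - mask: (batch, seq) bool, True marks real tokens
        if mask is not None:
            x = x * mask[..., None]                      # zero padded positions before mixing
        out = self.conv(x.transpose(1, 2)).transpose(1, 2)
        if mask is not None:
            out = out * mask[..., None]                  # keep padding zeroed in the output
        return out

# quick shape check: output width chosen to match heads * attn_dim_head = 512
branch = MaskedConvBranch(128, 512)
x = torch.randn(2, 1024, 128)
mask = torch.randint(0, 2, (2, 1024)).bool()
assert branch(x, mask = mask).shape == (2, 1024, 512)
```

Paired with the constructor arguments in this hunk, one would presumably pass `hybrid_module = MaskedConvBranch(dim, heads * dim_head)` together with `hybrid_mask_kwarg = 'mask'`, or the `attn_`-prefixed forms at the Encoder level, as with `attn_hybrid_module` in the test above.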
@@ -1592,7 +1596,21 @@ class Attention(Module):
         # hybrid module
 
         if exists(self.hybrid_module):
-
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
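The `tree_flatten` call above is what lets the hybrid module return either a bare tensor or a tuple: a GRU, for instance, returns `(output, h_n)`, and flattening that structure and taking the first leaf recovers the per-position output. A standalone sketch of the same extraction:

```python
import torch
from torch.nn import GRU
from torch.utils._pytree import tree_flatten

gru = GRU(128, 64 * 8, batch_first = True)
x = torch.randn(2, 1024, 128)

hybrid_outputs = gru(x)                                   # a (output, h_n) tuple
(hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)

assert hybrid_out.shape == (2, 1024, 512)                 # first leaf: per-position output
assert rest_hybrid_outs[0].shape == (1, 2, 512)           # remaining leaf: final hidden state
```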