x-transformers 1.43.5__tar.gz → 1.44.1__tar.gz
- {x_transformers-1.43.5/x_transformers.egg-info → x_transformers-1.44.1}/PKG-INFO +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.1}/README.md +14 -1
- {x_transformers-1.43.5 → x_transformers-1.44.1}/setup.py +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.1}/tests/test_x_transformers.py +34 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/x_transformers.py +42 -10
- {x_transformers-1.43.5 → x_transformers-1.44.1/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.1}/LICENSE +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/setup.cfg +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/__init__.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/attend.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/continuous.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/dpo.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/xval.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.43.5 → x_transformers-1.44.1}/README.md

@@ -317,7 +317,9 @@ model = TransformerWrapper(
 
 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).
 
-Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.
+
+Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, termed persistent memory
 
 ### Transformers Without Tears
 
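The memory/register-token notes above refer to the library's existing memory token support. A minimal sketch of enabling it, using the documented `num_memory_tokens` argument of `TransformerWrapper` (the specific values here are illustrative only):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# minimal sketch: prepend a few learned memory ("register") tokens that every
# position can attend to, which the updates above suggest helps with outliers
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,   # number of learned memory / register tokens
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)             # memory tokens are prepended and handled internally
```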
@@ -2374,4 +2376,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2024hymba,
+    title     = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author    = {Anonymous},
+    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+    year      = {2024},
+    url       = {https://openreview.net/forum?id=A1ztozypga},
+    note      = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.43.5 → x_transformers-1.44.1}/tests/test_x_transformers.py

@@ -613,3 +613,37 @@ def test_hyper_connections(tanh):
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+def test_hybrid():
+    from torch.nn import GRU
+
+    dec = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed = dec(x)
+
+    enc = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
+        )
+    )
+
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+    embed = enc(x, mask = mask)
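As a usage follow-up to the new test, a hedged sketch of running the hybrid decoder through the library's `AutoregressiveWrapper` for sampling. It assumes the wrapper's `cache_kv` flag can be used to disable kv caching, on the assumption that caching does not yet account for the recurrent branch's state; hyperparameters simply mirror the test:

```python
import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

# hybrid decoder as in the test: a GRU runs alongside attention in each layer
dec = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
    )
)

wrapper = AutoregressiveWrapper(dec)

prompt = torch.randint(0, 20000, (1, 32))

# sample 256 tokens; caching is disabled so the GRU branch always sees the
# full prefix rather than a single cached step (an assumption, see note above)
generated = wrapper.generate(prompt, 256, cache_kv = False)
```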
{x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/x_transformers.py

@@ -2,14 +2,16 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
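The new `tree_flatten` import is what later lets the hybrid path accept modules whose forward returns nested tuples. A small standalone sketch of that behaviour (the `LSTM` here is just an example of a module with a nested output, not something the diff itself uses):

```python
import torch
from torch.nn import LSTM
from torch.utils._pytree import tree_flatten

lstm = LSTM(128, 128, batch_first = True)
x = torch.randn(2, 16, 128)

outputs = lstm(x)   # (output, (h_n, c_n)) - a nested tuple

# tree_flatten collapses any nesting into a flat list of tensors, so the
# first leaf can be treated as the module's "main" output
(primary, *rest), _treespec = tree_flatten(outputs)

print(primary.shape)   # torch.Size([2, 16, 128])
print(len(rest))       # 2  (h_n and c_n)
```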
@@ -1136,6 +1138,8 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1335,6 +1339,12 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
@@ -1407,6 +1417,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1427,14 +1447,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
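The two hunks above move qk normalization so it runs before the kv caching logic in `forward`. For reference, a minimal standalone sketch of the idea behind qk ("cosine similarity") normalization, using `F.normalize` in place of the library's grouped `l2norm` helper; the scale value is illustrative:

```python
import torch
import torch.nn.functional as F

# toy shapes: (batch, heads, seq, dim_head)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)

# l2-normalize along the head dimension so q @ k^T becomes cosine similarity
q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))

# with unit-norm q and k, the usual 1/sqrt(d) factor is replaced by a larger
# (often learned) temperature
scale = 10.

sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
attn = sim.softmax(dim = -1)
```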
@@ -1581,6 +1593,26 @@ class Attention(Module):
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
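Conceptually, the block above runs the layer input through the hybrid module in parallel with attention and averages the two streams. A self-contained sketch of that merge outside the library (`MultiheadAttention` stands in for the attention branch; the 50/50 weighting mirrors `out = 0.5 * (out + hybrid_out)` in the diff):

```python
import torch
from torch.nn import GRU, MultiheadAttention
from torch.utils._pytree import tree_flatten

dim = 128
x = torch.randn(2, 1024, dim)

attn = MultiheadAttention(dim, num_heads = 8, batch_first = True)
hybrid = GRU(dim, dim, batch_first = True)

# attention branch
attn_out, _ = attn(x, x, x)

# hybrid branch - tree_flatten pulls the first tensor out of the GRU's
# (output, h_n) tuple to use as the branch output
(hybrid_out, *_), _ = tree_flatten(hybrid(x))

# simple 50/50 merge of the two streams, as in the diff above
out = 0.5 * (attn_out + hybrid_out)
```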
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,