x-transformers 1.43.5__tar.gz → 1.44.0__tar.gz
- {x_transformers-1.43.5/x_transformers.egg-info → x_transformers-1.44.0}/PKG-INFO +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.0}/README.md +11 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/setup.py +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.0}/tests/test_x_transformers.py +19 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/x_transformers.py +23 -9
- {x_transformers-1.43.5 → x_transformers-1.44.0/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.43.5 → x_transformers-1.44.0}/LICENSE +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/setup.cfg +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/__init__.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/attend.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/continuous.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/dpo.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/xval.py +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.43.5 → x_transformers-1.44.0}/README.md

@@ -2374,4 +2374,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2024hymba,
+    title     = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author    = {Anonymous},
+    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+    year      = {2024},
+    url       = {https://openreview.net/forum?id=A1ztozypga},
+    note      = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-1.43.5 → x_transformers-1.44.0}/tests/test_x_transformers.py

@@ -613,3 +613,22 @@ def test_hyper_connections(tanh):
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+def test_hybrid():
+    from torch.nn import GRU
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed = model(x)
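The new `test_hybrid` test passes a plain `torch.nn.GRU` as `attn_hybrid_module`. Note the hidden size of `64 * 8`: as I read the attention changes further down, the hybrid branch's output is averaged with the merged attention heads, so it has to produce `heads * attn_dim_head` features per token. A minimal sketch of that shape constraint (illustrative only, using standard `torch.nn.GRU` semantics):

```python
# Illustrative sketch of the shape constraint behind GRU(128, 64 * 8) in the test above:
# the hybrid module must map (batch, seq, dim) -> (batch, seq, heads * attn_dim_head)
# so its output can be averaged with the merged attention output.
import torch
from torch.nn import GRU

dim, heads, dim_head = 128, 8, 64

gru = GRU(dim, heads * dim_head, batch_first = True)     # hidden size = 8 * 64 = 512

x = torch.randn(2, 1024, dim)                            # same shape as the attention block's input
hybrid_out, _ = gru(x)                                   # GRU returns (output, final hidden state)

assert hybrid_out.shape == (2, 1024, heads * dim_head)   # matches attention's (b, n, h * d) output
```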
{x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/x_transformers.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
@@ -1136,6 +1137,7 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1335,6 +1337,10 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
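One module instance is handed in through the `attn_hybrid_module` keyword, and each `Attention` layer keeps a `deepcopy` of it, presumably so the layers train independent recurrent weights rather than sharing one set. A small sketch of the difference (my illustration, not package code):

```python
# Sketch: deepcopy gives every layer its own parameters; reusing the one instance would share them.
from copy import deepcopy
from torch.nn import GRU

template = GRU(128, 512, batch_first = True)

layer_a = deepcopy(template)
layer_b = deepcopy(template)

# independent tensors - training one copy leaves the other untouched
assert layer_a.weight_ih_l0 is not layer_b.weight_ih_l0
```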
@@ -1407,6 +1413,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1427,14 +1443,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
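The two hunks above move QK normalization so it now runs before the cache handling instead of after it; as far as I can tell, keys are therefore already normalized and scaled when they are concatenated with, and stored into, the KV cache, rather than being re-processed on every decoding step. For context, a rough sketch of what l2-style QK normalization does (illustrative only, not the package's `l2norm` helper, and the scale value here is arbitrary):

```python
# Rough sketch of qk normalization: q and k are unit-normalized along the head
# dimension and a fixed scale replaces 1/sqrt(d) in the attention logits.
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 16, 64

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)

q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)   # l2 normalize along the head dim

qk_norm_scale = 10.                                         # illustrative value
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * qk_norm_scale
```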
@@ -1581,6 +1589,12 @@ class Attention(Module):
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+            hybrid_out, _ = self.hybrid_module(x)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
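This is the heart of the Hymba-style change: the hybrid module runs on the same input `x` that fed attention, and the two branch outputs are simply averaged before the value gating and output projection. A self-contained sketch of the idea (toy shapes, one big softmax attention head, no causal mask):

```python
# Toy sketch of the hybrid combination: an attention branch and a recurrent branch
# consume the same input and their outputs are averaged 50/50.
import torch
import torch.nn.functional as F
from torch.nn import GRU, Linear

b, n, dim, heads, dim_head = 2, 16, 128, 8, 64
inner = heads * dim_head

x = torch.randn(b, n, dim)

# stand-in attention branch (single head for brevity, no causal mask)
to_qkv = Linear(dim, inner * 3, bias = False)
q, k, v = to_qkv(x).chunk(3, dim = -1)
attn = F.softmax(q @ k.transpose(-1, -2) * q.shape[-1] ** -0.5, dim = -1)
out = attn @ v                       # (b, n, inner)

# hybrid recurrent branch on the same input
gru = GRU(dim, inner, batch_first = True)
hybrid_out, _ = gru(x)               # (b, n, inner)

out = 0.5 * (out + hybrid_out)       # simple average, as in the hunk above
```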
@@ -2003,7 +2017,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
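The last hunk adds the closing `])` to the `all([...])` expression. The pattern itself is just `Module.modules()` filtered down to `Attention` instances; a toy sketch with a stand-in class (not the package's):

```python
# Toy sketch of the can_cache_kv aggregation: collect the flag from every Attention
# submodule and require all of them to support kv caching.
import torch.nn as nn

class Attention(nn.Module):                      # stand-in, not the package's class
    def __init__(self, can_cache_kv = True):
        super().__init__()
        self.can_cache_kv = can_cache_kv

net = nn.Sequential(Attention(), nn.Linear(4, 4), Attention(can_cache_kv = False))

can_cache_kv = all([m.can_cache_kv for m in net.modules() if isinstance(m, Attention)])
assert can_cache_kv is False
```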