x-transformers 1.43.5.tar.gz → 1.44.0.tar.gz

Files changed (22)
  1. {x_transformers-1.43.5/x_transformers.egg-info → x_transformers-1.44.0}/PKG-INFO +1 -1
  2. {x_transformers-1.43.5 → x_transformers-1.44.0}/README.md +11 -0
  3. {x_transformers-1.43.5 → x_transformers-1.44.0}/setup.py +1 -1
  4. {x_transformers-1.43.5 → x_transformers-1.44.0}/tests/test_x_transformers.py +19 -0
  5. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/x_transformers.py +23 -9
  6. {x_transformers-1.43.5 → x_transformers-1.44.0/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.43.5 → x_transformers-1.44.0}/LICENSE +0 -0
  8. {x_transformers-1.43.5 → x_transformers-1.44.0}/setup.cfg +0 -0
  9. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.43.5 → x_transformers-1.44.0}/x_transformers.egg-info/top_level.txt +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.43.5
+ Version: 1.44.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang

README.md
@@ -2374,4 +2374,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{anonymous2024hymba,
+     title = {Hymba: A Hybrid-head Architecture for Small Language Models},
+     author = {Anonymous},
+     booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+     year = {2024},
+     url = {https://openreview.net/forum?id=A1ztozypga},
+     note = {under review}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis

setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.43.5',
+   version = '1.44.0',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',

tests/test_x_transformers.py
@@ -613,3 +613,22 @@ def test_hyper_connections(tanh):
      x = torch.randint(0, 20000, (2, 1024))

      model(x)
+
+ def test_hybrid():
+     from torch.nn import GRU
+
+     model = TransformerWrapper(
+         num_tokens = 20000,
+         max_seq_len = 1024,
+         attn_layers = Decoder(
+             dim = 128,
+             depth = 6,
+             heads = 8,
+             attn_dim_head = 64,
+             attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+         )
+     )
+
+     x = torch.randint(0, 20000, (2, 1024))
+
+     embed = model(x)
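
For readers skimming the diff: the new `attn_hybrid_module` keyword travels from `Decoder` down to each `Attention` layer via the usual `attn_` prefix routing, and, as far as these hunks show, its output is averaged directly with the flattened attention heads, so its feature size must equal `heads * attn_dim_head` (`8 * 64 = 512` in the test above). A minimal usage sketch along the same lines, with illustrative sizes that are not part of the diff:

```python
# sketch based on the 1.44.0 API shown in this diff; model sizes are illustrative
import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

heads, dim_head = 8, 64

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = heads,
        attn_dim_head = dim_head,
        # the GRU hidden size matches heads * dim_head, since its output is averaged
        # with the concatenated attention heads (see the forward hunk further below)
        attn_hybrid_module = GRU(128, heads * dim_head, batch_first = True)
    )
)

x = torch.randint(0, 20000, (2, 256))
logits = model(x)   # (2, 256, 20000)
```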

x_transformers/x_transformers.py
@@ -2,6 +2,7 @@ from __future__ import annotations
  from typing import Callable

  import math
+ from copy import deepcopy
  from random import random, randrange
  from packaging import version


@@ -1136,6 +1137,7 @@ class Attention(Module):
          sigmoid = False,
          selective = False,
          custom_attn_fn: Callable | None = None,
+         hybrid_module: Module | None = None,
          one_kv_head = False,
          kv_heads = None,
          shared_kv = False,

@@ -1335,6 +1337,10 @@ class Attention(Module):

          self.attn_on_attn = on_attn

+         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
          # output dimension by default same as input, but can be overridden

          dim_out = default(dim_out, dim)
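
A note on the `deepcopy` above (the reasoning is mine, not stated in the diff): the same `hybrid_module` instance is handed to every layer's `Attention`, so deep-copying gives each layer its own recurrent weights rather than sharing one module across the depth. A small illustration of that behaviour:

```python
# illustration of the per-layer deepcopy assumed above: one prototype module,
# independent copies (and therefore independent parameters) per layer
from copy import deepcopy
from torch.nn import GRU

prototype = GRU(128, 512, batch_first = True)
per_layer = [deepcopy(prototype) for _ in range(6)]   # e.g. depth = 6

assert per_layer[0] is not per_layer[1]
assert per_layer[0].weight_ih_l0 is not per_layer[1].weight_ih_l0   # separate tensors
```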

@@ -1407,6 +1413,16 @@ class Attention(Module):
              value_residual_mix = self.to_value_residual_mix(q_input)
              v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

+         # qk normalization
+
+         if self.qk_norm:
+             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+             q, k = map(qk_l2norm, (q, k))
+             scale = self.qk_norm_scale
+
+             q = q * self.qk_norm_q_scale
+             k = k * self.qk_norm_k_scale
+
          # take care of caching

          if exists(cache):

@@ -1427,14 +1443,6 @@ class Attention(Module):
              mem_len = mem.shape[-2] if exists(mem) else 0
              cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

-         if self.qk_norm:
-             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-             q, k = map(qk_l2norm, (q, k))
-             scale = self.qk_norm_scale
-
-             q = q * self.qk_norm_q_scale
-             k = k * self.qk_norm_k_scale
-
          if exists(rotary_pos_emb):
              freqs, xpos_scale = rotary_pos_emb
              q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
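
Taken together, the two hunks above relocate QK normalization so that queries and keys are l2-normalized and rescaled by the learned `qk_norm_q_scale` / `qk_norm_k_scale` before the kv-cache handling rather than after it, presumably so cached keys are stored already normalized and are not re-normalized on later decoding steps. A rough standalone sketch of what that normalization computes, using `torch.nn.functional.normalize` in place of the library's `l2norm` helper and with assumed parameter shapes:

```python
# rough sketch of qk normalization; not the library's exact l2norm / grouping logic
import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 1024, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)

# unit-normalize queries and keys along the head dimension
q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)

# learned per-head, per-dimension scales (shapes assumed for illustration)
q_scale = torch.ones(h, 1, d)
k_scale = torch.ones(h, 1, d)
q, k = q * q_scale, k * k_scale

# with unit-norm q and k the dot product is a cosine similarity, so the usual
# 1 / sqrt(d) attention scale is replaced by the qk_norm_scale seen in the diff
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
```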

@@ -1581,6 +1589,12 @@ class Attention(Module):

          out = rearrange(out, 'b h n d -> b n (h d)')

+         # hybrid module
+
+         if exists(self.hybrid_module):
+             hybrid_out, _ = self.hybrid_module(x)
+             out = 0.5 * (out + hybrid_out)
+
          # alphafold2 styled gating of the values

          if exists(self.to_v_gate):
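
The merge itself is an equal-weight average: the hybrid module consumes the same block input `x` that produced the queries, and its sequence output is averaged with the flattened multi-head attention output before the value gating and the output projection, loosely following Hymba's parallel heads. A shape-only sketch of that step, with illustrative dimensions:

```python
# shape-only sketch of the 0.5 * (out + hybrid_out) merge; dimensions are illustrative
import torch
from torch.nn import GRU
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 128

x = torch.randn(b, n, dim)             # input to the attention block
attn_out = torch.randn(b, h, n, d)     # stand-in for the per-head attention output

out = rearrange(attn_out, 'b h n d -> b n (h d)')    # (b, n, h * d)

hybrid = GRU(dim, h * d, batch_first = True)
hybrid_out, _ = hybrid(x)                            # (b, n, h * d)

out = 0.5 * (out + hybrid_out)                       # equal-weight average of the two branches
```

Note that the hybrid module is expected to return a tuple of `(output, hidden)` the way a GRU does, since the forward hunk unpacks two values.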

@@ -2003,7 +2017,7 @@ class AttentionLayers(Module):

          # determine whether can cache kv

-         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

      def forward(
          self,

x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.43.5
+ Version: 1.44.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang