x-transformers 1.43.5.tar.gz → 1.44.1.tar.gz

Files changed (22)
  1. {x_transformers-1.43.5/x_transformers.egg-info → x_transformers-1.44.1}/PKG-INFO +1 -1
  2. {x_transformers-1.43.5 → x_transformers-1.44.1}/README.md +14 -1
  3. {x_transformers-1.43.5 → x_transformers-1.44.1}/setup.py +1 -1
  4. {x_transformers-1.43.5 → x_transformers-1.44.1}/tests/test_x_transformers.py +34 -0
  5. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/x_transformers.py +42 -10
  6. {x_transformers-1.43.5 → x_transformers-1.44.1/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.43.5 → x_transformers-1.44.1}/LICENSE +0 -0
  8. {x_transformers-1.43.5 → x_transformers-1.44.1}/setup.cfg +0 -0
  9. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.43.5 → x_transformers-1.44.1}/x_transformers.egg-info/top_level.txt +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.5
+Version: 1.44.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang

README.md
@@ -317,7 +317,9 @@ model = TransformerWrapper(

 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).

-Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.
+
+Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, termed persistent memory

 ### Transformers Without Tears

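For reference, the memory tokens this README passage discusses are enabled through the `num_memory_tokens` argument of `TransformerWrapper`; a minimal sketch, mirroring the README example this hunk sits beneath:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,   # 20 learned memory / register tokens prepended to the sequence
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)             # (1, 1024, 20000)
```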
@@ -2374,4 +2376,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@inproceedings{anonymous2024hymba,
+    title     = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author    = {Anonymous},
+    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+    year      = {2024},
+    url       = {https://openreview.net/forum?id=A1ztozypga},
+    note      = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis

setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.5',
+  version = '1.44.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',

tests/test_x_transformers.py
@@ -613,3 +613,37 @@ def test_hyper_connections(tanh):
     x = torch.randint(0, 20000, (2, 1024))

     model(x)
+
+def test_hybrid():
+    from torch.nn import GRU
+
+    dec = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed = dec(x)
+
+    enc = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
+        )
+    )
+
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+    embed = enc(x, mask = mask)

x_transformers/x_transformers.py
@@ -2,14 +2,16 @@ from __future__ import annotations
 from typing import Callable

 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version

 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast

 from functools import partial, wraps
 from collections import namedtuple
@@ -1136,6 +1138,8 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
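These two new `Attention` arguments are reached from the `Decoder` / `Encoder` level through the library's usual `attn_`-prefix kwarg routing, which is how the new test passes `attn_hybrid_module`. A hedged sketch (the commented-out mask kwarg is hypothetical and only applies to hybrid modules that accept one):

```python
from torch.nn import GRU
from x_transformers import Encoder

# `attn_`-prefixed kwargs are forwarded to Attention, so `attn_hybrid_module` becomes `hybrid_module`
enc_layers = Encoder(
    dim = 128,
    depth = 2,
    heads = 8,
    attn_dim_head = 64,
    attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True),
    # attn_hybrid_mask_kwarg = 'mask',  # hypothetical: only if the hybrid module takes a mask keyword
)
```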
@@ -1335,6 +1339,12 @@ class Attention(Module):

         self.attn_on_attn = on_attn

+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden

         dim_out = default(dim_out, dim)
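Note the `deepcopy`: each `Attention` layer stores its own copy of the passed-in module, so handing a single instance to a multi-layer stack does not share weights across depth. A standalone sketch of that pattern (the GRU template and sizes here are illustrative):

```python
from copy import deepcopy
from torch.nn import GRU

# one template module, deep-copied per layer so parameters are not shared across depth
template = GRU(128, 512, batch_first = True)
per_layer = [deepcopy(template) for _ in range(6)]

assert per_layer[0].weight_ih_l0 is not template.weight_ih_l0   # independent parameters
```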
@@ -1407,6 +1417,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching

         if exists(cache):
@@ -1427,14 +1447,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
@@ -1581,6 +1593,26 @@ class Attention(Module):

         out = rearrange(out, 'b h n d -> b n (h d)')

+        # hybrid module
+
+        if exists(self.hybrid_module):
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values

         if exists(self.to_v_gate):
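A module like `GRU` returns a tuple `(output, hidden)` rather than a single tensor, which is why `tree_flatten` is used above to grab the first tensor leaf before averaging it with the attention branch. A standalone sketch of that pattern (module choice and sizes are illustrative):

```python
import torch
from torch.nn import GRU
from torch.utils._pytree import tree_flatten

gru = GRU(128, 128, batch_first = True)
x = torch.randn(2, 16, 128)

# GRU forward returns (output, hidden); tree_flatten yields the tensor leaves in order
(hybrid_out, *rest), _ = tree_flatten(gru(x))

print(hybrid_out.shape)   # torch.Size([2, 16, 128]) - the per-token output used in the 0.5 * (out + hybrid_out) merge
```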
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):

         # determine whether can cache kv

-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

     def forward(
         self,

x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.5
+Version: 1.44.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang