x-transformers 1.44.0__tar.gz → 1.44.2__tar.gz

Files changed (22)
  1. {x_transformers-1.44.0/x_transformers.egg-info → x_transformers-1.44.2}/PKG-INFO +1 -1
  2. {x_transformers-1.44.0 → x_transformers-1.44.2}/README.md +3 -1
  3. {x_transformers-1.44.0 → x_transformers-1.44.2}/setup.py +1 -1
  4. {x_transformers-1.44.0 → x_transformers-1.44.2}/tests/test_x_transformers.py +17 -2
  5. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/x_transformers.py +20 -2
  6. {x_transformers-1.44.0 → x_transformers-1.44.2/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.44.0 → x_transformers-1.44.2}/LICENSE +0 -0
  8. {x_transformers-1.44.0 → x_transformers-1.44.2}/setup.cfg +0 -0
  9. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.44.0 → x_transformers-1.44.2}/x_transformers.egg-info/top_level.txt +0 -0
x_transformers-1.44.0/x_transformers.egg-info/PKG-INFO → x_transformers-1.44.2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.44.0
+Version: 1.44.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.44.0/README.md → x_transformers-1.44.2/README.md
@@ -317,7 +317,9 @@ model = TransformerWrapper(
 
 Update: MetaAI researchers <a href="https://arxiv.org/abs/2309.16588">have found</a> that adding memory tokens (they call them register tokens), alleviates outliers (which is suspected now to be a pathology of attention networks unable to <a href="https://arxiv.org/abs/2306.12929">attend to nothing</a>).
 
-Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper
+Update 2: a hybrid architecture out of Nvidia named <a href="https://openreview.net/forum?id=A1ztozypga">Hymba</a> used memory tokens successfully in the autoregressive case, termed meta tokens in their paper.
+
+Update 3: further corroborated by <a href="https://arxiv.org/abs/2501.00663">a paper</a> trying to extend memory in attention networks, termed persistent memory
 
 ### Transformers Without Tears
 
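For context, the memory / register tokens discussed in this README update are exposed through the `num_memory_tokens` argument of `TransformerWrapper`. A minimal sketch, not part of this diff (dimensions are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    num_memory_tokens = 20,   # learned "register" / "meta" tokens prepended to every sequence
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)             # (1, 1024, 20000) - memory tokens are stripped before the logits
```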
x_transformers-1.44.0/setup.py → x_transformers-1.44.2/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.0',
+  version = '1.44.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers-1.44.0/tests/test_x_transformers.py → x_transformers-1.44.2/tests/test_x_transformers.py
@@ -617,7 +617,7 @@ def test_hyper_connections(tanh):
 def test_hybrid():
     from torch.nn import GRU
 
-    model = TransformerWrapper(
+    dec = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
         attn_layers = Decoder(
@@ -631,4 +631,19 @@ def test_hybrid():
 
     x = torch.randint(0, 20000, (2, 1024))
 
-    embed = model(x)
+    embed = dec(x)
+
+    enc = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 4, batch_first = True, bidirectional = True)
+        )
+    )
+
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+    embed = enc(x, mask = mask)
x_transformers-1.44.0/x_transformers/x_transformers.py → x_transformers-1.44.2/x_transformers/x_transformers.py
@@ -7,10 +7,11 @@ from random import random, randrange
 from packaging import version
 
 import torch
+from torch.amp import autocast
 import torch.nn.functional as F
 from torch import nn, einsum, Tensor
+from torch.utils._pytree import tree_flatten
 from torch.nn import Module, ModuleList, ModuleDict
-from torch.amp import autocast
 
 from functools import partial, wraps
 from collections import namedtuple
@@ -1138,6 +1139,7 @@ class Attention(Module):
         selective = False,
         custom_attn_fn: Callable | None = None,
         hybrid_module: Module | None = None,
+        hybrid_mask_kwarg: str | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1341,6 +1343,8 @@ class Attention(Module):
 
         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
 
+        self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
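The comment above is the whole contract: when the attention layer is non-causal and `hybrid_mask_kwarg` is set, the padding mask is forwarded into the hybrid module under that keyword (see the forward hunk below). A hedged sketch of how this might be wired up; `MaskedGRU` is a hypothetical wrapper, not part of the library, and it assumes `attn_`-prefixed kwargs are routed into `Attention` the same way `attn_hybrid_module` is in the test above:

```python
import torch
from torch import nn
from x_transformers import TransformerWrapper, Encoder

class MaskedGRU(nn.Module):
    # hypothetical module: a bidirectional GRU whose forward accepts a `mask` keyword
    def __init__(self, dim, dim_hidden):
        super().__init__()
        self.gru = nn.GRU(dim, dim_hidden, batch_first = True, bidirectional = True)

    def forward(self, x, mask = None):
        if mask is not None:
            x = x * mask[..., None]   # zero out padded positions before the recurrence
        out, _ = self.gru(x)          # out: (batch, seq, 2 * dim_hidden)
        return out

enc = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_module = MaskedGRU(128, 64 * 4),   # output dim 2 * 256 = 512 = heads * dim_head
        attn_hybrid_mask_kwarg = 'mask'                # assumption: routed to Attention's `hybrid_mask_kwarg`
    )
)

x = torch.randint(0, 20000, (2, 1024))
mask = torch.randint(0, 2, (2, 1024)).bool()
embed = enc(x, mask = mask)   # the same padding mask is now also seen by MaskedGRU
```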
@@ -1592,7 +1596,21 @@ class Attention(Module):
         # hybrid module
 
         if exists(self.hybrid_module):
-            hybrid_out, _ = self.hybrid_module(x)
+
+            # hybrid input
+
+            hybrid_forward_kwargs = dict()
+
+            if not self.causal and exists(self.hybrid_mask_kwarg):
+                hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+            # hybrid forward
+
+            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+            # handle hybrid out
+
+            (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
             out = 0.5 * (out + hybrid_out)
 
         # alphafold2 styled gating of the values
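The switch to `tree_flatten` is what lets the hybrid module return either a bare tensor or a tuple such as a GRU's `(output, h_n)`: the first flattened tensor becomes the hybrid branch output, and anything else is kept in `rest_hybrid_outs`. A standalone illustration of that unpacking (not from the diff; shapes are illustrative):

```python
import torch
from torch.utils._pytree import tree_flatten

gru = torch.nn.GRU(128, 256, batch_first = True)
x = torch.randn(2, 1024, 128)

outputs = gru(x)                                  # GRU returns a tuple: (output, h_n)
(hybrid_out, *rest), _ = tree_flatten(outputs)    # first leaf is the per-position output
assert hybrid_out.shape == (2, 1024, 256)

single = torch.randn(2, 1024, 256)                # a module returning just a tensor also works
(hybrid_out, *rest), _ = tree_flatten(single)
assert rest == []
```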
x_transformers-1.44.0/PKG-INFO → x_transformers-1.44.2/x_transformers.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.44.0
+Version: 1.44.2
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang