x-transformers 1.43.4__tar.gz → 1.44.0__tar.gz

Files changed (22)
  1. {x_transformers-1.43.4/x_transformers.egg-info → x_transformers-1.44.0}/PKG-INFO +1 -1
  2. {x_transformers-1.43.4 → x_transformers-1.44.0}/README.md +11 -0
  3. {x_transformers-1.43.4 → x_transformers-1.44.0}/setup.py +1 -1
  4. {x_transformers-1.43.4 → x_transformers-1.44.0}/tests/test_x_transformers.py +19 -0
  5. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/x_transformers.py +40 -9
  6. {x_transformers-1.43.4 → x_transformers-1.44.0/x_transformers.egg-info}/PKG-INFO +1 -1
  7. {x_transformers-1.43.4 → x_transformers-1.44.0}/LICENSE +0 -0
  8. {x_transformers-1.43.4 → x_transformers-1.44.0}/setup.cfg +0 -0
  9. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/neo_mlp.py +0 -0
  16. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  17. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  18. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers/xval.py +0 -0
  19. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers.egg-info/SOURCES.txt +0 -0
  20. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers.egg-info/dependency_links.txt +0 -0
  21. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers.egg-info/requires.txt +0 -0
  22. {x_transformers-1.43.4 → x_transformers-1.44.0}/x_transformers.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.4
+Version: 1.44.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -2374,4 +2374,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@inproceedings{anonymous2024hymba,
+    title = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author = {Anonymous},
+    booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
+    year = {2024},
+    url = {https://openreview.net/forum?id=A1ztozypga},
+    note = {under review}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.43.4',
+  version = '1.44.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
@@ -613,3 +613,22 @@ def test_hyper_connections(tanh):
     x = torch.randint(0, 20000, (2, 1024))
 
     model(x)
+
+def test_hybrid():
+    from torch.nn import GRU
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+
+    embed = model(x)
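The new test above is the clearest picture of the feature: a recurrent module passed as `attn_hybrid_module` runs alongside attention, and its hidden size is set to `heads * attn_dim_head` (8 * 64 = 512) so the two outputs can be averaged. A standalone sketch of the same usage outside the test harness, assuming this release is installed; the hyperparameters are the test's, not recommendations:

```python
# hedged usage sketch mirroring the new test_hybrid in tests/test_x_transformers.py
import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_dim_head = 64,
        # recurrent branch run alongside attention; hidden size matches
        # heads * attn_dim_head (8 * 64 = 512) so the two outputs can be averaged
        attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)   # (2, 1024, 20000)
```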
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 import math
+from copy import deepcopy
 from random import random, randrange
 from packaging import version
 
@@ -38,6 +39,7 @@ class LayerIntermediates:
     attn_z_loss: Tensor | None = None
     mems: Tensor | None = None
     memory_tokens: Tensor | None = None
+    logit_entropies: Tensor | None = None
 
 LinearNoBias = partial(nn.Linear, bias = False)
 
@@ -136,6 +138,15 @@ def or_reduce(masks):
         head = head | rest
     return head
 
+# entropy
+
+def calc_entropy(
+    t: Tensor,
+    is_prob = False
+):
+    prob = t.softmax(dim = -1) if not is_prob else t
+    return -(prob * log(prob)).sum(dim = -1)
+
 # auxiliary loss helpers
 
 def calc_z_loss(
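For reference, the new `calc_entropy` helper is plain Shannon entropy over the last dimension. A self-contained sketch of the same computation (with an explicit clamp standing in for the module's own `log` helper, which is an assumption here), checked against `torch.distributions.Categorical`:

```python
import torch

def entropy_from_logits(t: torch.Tensor) -> torch.Tensor:
    # same computation as the new calc_entropy for the default is_prob = False path
    prob = t.softmax(dim = -1)
    return -(prob * prob.clamp(min = 1e-20).log()).sum(dim = -1)

logits = torch.randn(2, 1024, 256)
ent = entropy_from_logits(logits)                                  # (2, 1024)

# agrees with torch's categorical entropy
ref = torch.distributions.Categorical(logits = logits).entropy()
assert torch.allclose(ent, ref, atol = 1e-3)
```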
@@ -1126,6 +1137,7 @@ class Attention(Module):
         sigmoid = False,
         selective = False,
         custom_attn_fn: Callable | None = None,
+        hybrid_module: Module | None = None,
         one_kv_head = False,
         kv_heads = None,
         shared_kv = False,
@@ -1325,6 +1337,10 @@ class Attention(Module):
 
         self.attn_on_attn = on_attn
 
+        # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+        self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
         # output dimension by default same as input, but can be overridden
 
         dim_out = default(dim_out, dim)
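Note the `deepcopy` above: if the usual `attn_*` kwarg routing applies (an assumption, since it is not shown in this diff), a single module passed to the attention layers is handed to every layer, and the copy gives each layer its own independent weights rather than one shared recurrence. A small sketch of that distinction:

```python
import torch
from copy import deepcopy
from torch.nn import GRU

shared = GRU(128, 512, batch_first = True)
per_layer = [deepcopy(shared) for _ in range(6)]   # what each Attention would end up holding

# copies start with identical weights but are separate parameters,
# so they train independently instead of being tied across layers
assert torch.equal(shared.weight_ih_l0, per_layer[0].weight_ih_l0)
assert per_layer[0].weight_ih_l0 is not per_layer[1].weight_ih_l0
```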
@@ -1397,6 +1413,16 @@ class Attention(Module):
             value_residual_mix = self.to_value_residual_mix(q_input)
             v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+        # qk normalization
+
+        if self.qk_norm:
+            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+            q, k = map(qk_l2norm, (q, k))
+            scale = self.qk_norm_scale
+
+            q = q * self.qk_norm_q_scale
+            k = k * self.qk_norm_k_scale
+
         # take care of caching
 
         if exists(cache):
@@ -1417,14 +1443,6 @@ class Attention(Module):
             mem_len = mem.shape[-2] if exists(mem) else 0
             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
-        if self.qk_norm:
-            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-            q, k = map(qk_l2norm, (q, k))
-            scale = self.qk_norm_scale
-
-            q = q * self.qk_norm_q_scale
-            k = k * self.qk_norm_k_scale
-
         if exists(rotary_pos_emb):
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
@@ -1571,6 +1589,12 @@ class Attention(Module):
 
         out = rearrange(out, 'b h n d -> b n (h d)')
 
+        # hybrid module
+
+        if exists(self.hybrid_module):
+            hybrid_out, _ = self.hybrid_module(x)
+            out = 0.5 * (out + hybrid_out)
+
         # alphafold2 styled gating of the values
 
         if exists(self.to_v_gate):
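The forward-pass hunk above is where the GRU hidden size requirement comes from: the hybrid module consumes the same input `x` as attention, and its output is averaged with the merged heads before the output projection. A minimal shape sketch, with a random tensor standing in for the attention output (an illustration, not the library's actual forward):

```python
import torch
from torch.nn import GRU

b, n, dim, heads, dim_head = 2, 16, 128, 8, 64

x = torch.randn(b, n, dim)                          # input to the attention block
attn_out = torch.randn(b, n, heads * dim_head)      # stand-in for rearranged 'b h n d -> b n (h d)'

hybrid = GRU(dim, heads * dim_head, batch_first = True)
hybrid_out, _ = hybrid(x)                           # (b, n, heads * dim_head)

out = 0.5 * (attn_out + hybrid_out)                 # averaged exactly as in the hunk above
```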
@@ -1993,7 +2017,7 @@ class AttentionLayers(Module):
 
         # determine whether can cache kv
 
-        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
     def forward(
         self,
@@ -2592,6 +2616,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_logit_entropies = False,
         mask = None,
         return_mems = False,
         return_attn = False,
@@ -2809,6 +2834,12 @@ class TransformerWrapper(Module):
         else:
             out = logits
 
+        # logit entropies
+
+        if return_logit_entropies:
+            intermediates.logit_entropies = calc_entropy(logits)
+            return_intermediates = True
+
         # aux loss
 
         if return_attn_z_loss:
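Finally, the `return_logit_entropies` flag flips `return_intermediates` on, so the call should yield a `(logits, intermediates)` pair with the per-position entropies attached. A hedged sketch of the caller's side, assuming the wrapper's existing return convention for intermediates:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 8)
)

x = torch.randint(0, 20000, (2, 256))

logits, intermediates = model(x, return_logit_entropies = True)

# one Shannon entropy per position, computed from the final logits
print(intermediates.logit_entropies.shape)   # torch.Size([2, 256])
```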
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.43.4
+Version: 1.44.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang