x-transformers 1.43.5__py3-none-any.whl → 1.44.1__py3-none-any.whl

x_transformers/x_transformers.py

@@ -2,14 +2,16 @@ from __future__ import annotations
  from typing import Callable
  
  import math
+ from copy import deepcopy
  from random import random, randrange
  from packaging import version
  
  import torch
+ from torch.amp import autocast
  import torch.nn.functional as F
  from torch import nn, einsum, Tensor
+ from torch.utils._pytree import tree_flatten
  from torch.nn import Module, ModuleList, ModuleDict
- from torch.amp import autocast
  
  from functools import partial, wraps
  from collections import namedtuple
@@ -1136,6 +1138,8 @@ class Attention(Module):
          sigmoid = False,
          selective = False,
          custom_attn_fn: Callable | None = None,
+         hybrid_module: Module | None = None,
+         hybrid_mask_kwarg: str | None = None,
          one_kv_head = False,
          kv_heads = None,
          shared_kv = False,
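The two new keyword arguments above let an Attention block run an auxiliary module in parallel with attention, in the spirit of Hymba. A minimal usage sketch follows; it is not taken from the package. It constructs Attention directly from x_transformers.x_transformers with its long-standing dim / heads / causal arguments and uses an nn.LSTM as a stand-in for the hybrid branch (any module whose first output leaf matches the inner attention dimension should work the same way):

# hedged sketch (not from this diff): wiring an LSTM in as the hybrid branch
import torch
from torch import nn
from x_transformers.x_transformers import Attention

dim = 512

attn = Attention(
    dim = dim,
    heads = 8,
    causal = True,
    hybrid_module = nn.LSTM(dim, dim, batch_first = True)  # runs in parallel with the attention branch
)

x = torch.randn(1, 1024, dim)
outputs = attn(x)  # the attention path and the LSTM path are averaged 50/50 before the output projection

For a non-causal (encoder-style) block, hybrid_mask_kwarg names the keyword through which the attention mask is forwarded to the hybrid branch, per the forward-pass hunk further down. When the block is built through the library's Encoder/Decoder stacks, the same options would normally be reached via the usual attn_-prefixed keyword forwarding (e.g. attn_hybrid_module = ...), though that plumbing is outside this diff.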
@@ -1335,6 +1339,12 @@ class Attention(Module):
  
          self.attn_on_attn = on_attn
  
+         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+ 
+         self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+ 
+         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+ 
          # output dimension by default same as input, but can be overridden
  
          dim_out = default(dim_out, dim)
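Because the block stores a deepcopy of whatever module is passed in, a single template module can be handed to several attention layers without them sharing parameters. A tiny illustration of that copy semantics, with hypothetical names:

# each deep copy owns independent parameters, so one template can be reused per layer
from copy import deepcopy
from torch import nn

template = nn.LSTM(512, 512, batch_first = True)

layer_a_branch = deepcopy(template)
layer_b_branch = deepcopy(template)

assert layer_a_branch.weight_ih_l0 is not layer_b_branch.weight_ih_l0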
@@ -1407,6 +1417,16 @@ class Attention(Module):
              value_residual_mix = self.to_value_residual_mix(q_input)
              v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
  
+         # qk normalization
+ 
+         if self.qk_norm:
+             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+             q, k = map(qk_l2norm, (q, k))
+             scale = self.qk_norm_scale
+ 
+             q = q * self.qk_norm_q_scale
+             k = k * self.qk_norm_k_scale
+ 
          # take care of caching
  
          if exists(cache):
@@ -1427,14 +1447,6 @@ class Attention(Module):
              mem_len = mem.shape[-2] if exists(mem) else 0
              cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
  
-         if self.qk_norm:
-             qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-             q, k = map(qk_l2norm, (q, k))
-             scale = self.qk_norm_scale
- 
-             q = q * self.qk_norm_q_scale
-             k = k * self.qk_norm_k_scale
- 
          if exists(rotary_pos_emb):
              freqs, xpos_scale = rotary_pos_emb
              q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
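These two hunks only move the qk-normalization block so that it now runs before the cache handling; the normalization itself is unchanged. For reference, here is a hedged sketch of what grouped q/k l2-normalization of this kind does. The body of the package's l2norm helper is not shown in this diff, so the version below is an assumption that follows the usual cosine-sim attention recipe:

# assumed grouped l2-normalization of queries and keys (not copied from the package)
import torch
import torch.nn.functional as F
from einops import rearrange

def l2norm(t, groups = 1):
    # split the feature dimension into groups and unit-normalize each group
    t = rearrange(t, '... (g d) -> ... g d', g = groups)
    t = F.normalize(t, p = 2, dim = -1)
    return rearrange(t, '... g d -> ... (g d)')

q = torch.randn(2, 8, 1024, 64)  # (batch, heads, seq, dim_head)
k = torch.randn(2, 8, 1024, 64)

q, k = map(lambda t: l2norm(t, groups = 1), (q, k))

# with unit-norm q and k the dot product is a cosine similarity, so a fixed or
# learned scale (qk_norm_scale above) stands in for the usual 1 / sqrt(dim_head)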
@@ -1581,6 +1593,26 @@ class Attention(Module):
  
          out = rearrange(out, 'b h n d -> b n (h d)')
  
+         # hybrid module
+ 
+         if exists(self.hybrid_module):
+ 
+             # hybrid input
+ 
+             hybrid_forward_kwargs = dict()
+ 
+             if not self.causal and exists(self.hybrid_mask_kwarg):
+                 hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+ 
+             # hybrid forward
+ 
+             hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+ 
+             # handle hybrid out
+ 
+             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+             out = 0.5 * (out + hybrid_out)
+ 
          # alphafold2 styled gating of the values
  
          if exists(self.to_v_gate):
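The tree_flatten call is what lets the hybrid branch return a bare tensor, a tuple, or any nested structure: the first tensor leaf is taken as the branch output and averaged with the attention output. A small sketch of that unpacking, again assuming an LSTM branch (the module choice is illustrative, not from the diff):

# how tree_flatten picks the sequence output out of a nested module return value
import torch
from torch import nn
from torch.utils._pytree import tree_flatten

lstm = nn.LSTM(512, 512, batch_first = True)
x = torch.randn(1, 1024, 512)

hybrid_outputs = lstm(x)  # (output, (h_n, c_n)), a nested structure rather than a bare tensor

(hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)

assert hybrid_out.shape == (1, 1024, 512)  # first leaf is the per-token output that gets averaged in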
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):
  
          # determine whether can cache kv
  
-         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+         self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
  
      def forward(
          self,
x_transformers-1.44.1.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.43.5
+ Version: 1.44.1
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.44.1.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=z3RH6jvjcxaAVfZoCS0HWrE0Gy55-eXOKtzRt7rRRIw,100811
+ x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-1.43.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.43.5.dist-info/METADATA,sha256=crd2xA3NbodKbOz9xY4D1j3XDbTmaY9vwkZZJOGoEw4,738
- x_transformers-1.43.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- x_transformers-1.43.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.43.5.dist-info/RECORD,,
+ x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
+ x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+ x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.44.1.dist-info/RECORD,,
x_transformers-1.44.1.dist-info/WHEEL

@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (75.7.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
  