x-transformers 1.43.5__py3-none-any.whl → 1.44.1__py3-none-any.whl

x_transformers/x_transformers.py

@@ -2,14 +2,16 @@ from __future__ import annotations
  from typing import Callable
 
  import math
+ from copy import deepcopy
  from random import random, randrange
  from packaging import version
 
  import torch
+ from torch.amp import autocast
  import torch.nn.functional as F
  from torch import nn, einsum, Tensor
+ from torch.utils._pytree import tree_flatten
  from torch.nn import Module, ModuleList, ModuleDict
- from torch.amp import autocast
 
  from functools import partial, wraps
  from collections import namedtuple
@@ -1136,6 +1138,8 @@ class Attention(Module):
  sigmoid = False,
  selective = False,
  custom_attn_fn: Callable | None = None,
+ hybrid_module: Module | None = None,
+ hybrid_mask_kwarg: str | None = None,
  one_kv_head = False,
  kv_heads = None,
  shared_kv = False,
@@ -1335,6 +1339,12 @@ class Attention(Module):
 
  self.attn_on_attn = on_attn
 
+ # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+ self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
+ self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths
+
  # output dimension by default same as input, but can be overridden
 
  dim_out = default(dim_out, dim)
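
The new `hybrid_module` / `hybrid_mask_kwarg` arguments attach a parallel sequence module (Hymba-style) to an attention block; its output is averaged 50/50 with the attention output in the forward-pass hunk further down. A minimal usage sketch, with illustrative dimensions chosen as assumptions so a GRU's hidden size matches the pre-projection attention output (heads * dim_head):

    import torch
    from torch.nn import GRU
    from x_transformers.x_transformers import Attention

    # sketch only: dim / heads / dim_head are illustrative values, picked so the
    # GRU output size (512) equals heads * dim_head and can be averaged with the attention output
    attn = Attention(
        dim = 512,
        heads = 8,
        dim_head = 64,
        causal = True,
        hybrid_module = GRU(512, 512, batch_first = True)  # deepcopied internally
    )

    x = torch.randn(2, 1024, 512)
    out, intermediates = attn(x)  # hybrid branch runs alongside attention on the same input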
@@ -1407,6 +1417,16 @@ class Attention(Module):
  value_residual_mix = self.to_value_residual_mix(q_input)
  v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
 
+ # qk normalization
+
+ if self.qk_norm:
+     qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+     q, k = map(qk_l2norm, (q, k))
+     scale = self.qk_norm_scale
+
+     q = q * self.qk_norm_q_scale
+     k = k * self.qk_norm_k_scale
+
  # take care of caching
 
  if exists(cache):
@@ -1427,14 +1447,6 @@ class Attention(Module):
  mem_len = mem.shape[-2] if exists(mem) else 0
  cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
- if self.qk_norm:
-     qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-     q, k = map(qk_l2norm, (q, k))
-     scale = self.qk_norm_scale
-
-     q = q * self.qk_norm_q_scale
-     k = k * self.qk_norm_k_scale
-
  if exists(rotary_pos_emb):
      freqs, xpos_scale = rotary_pos_emb
      q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
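
The two hunks above move qk normalization ahead of the caching block, so keys are cached already normalized. For context, `l2norm` with a `groups` argument unit-normalizes queries and keys per group along the feature dimension; a rough standalone sketch of the idea (not the library's exact implementation):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def grouped_l2norm(t, groups = 1):
        # split the last dim into groups, unit-normalize each group, re-merge
        t = rearrange(t, '... (g d) -> ... g d', g = groups)
        t = F.normalize(t, p = 2, dim = -1)
        return rearrange(t, '... g d -> ... (g d)')

    q = torch.randn(2, 8, 16, 64)      # (batch, heads, seq, dim_head)
    q = grouped_l2norm(q, groups = 1)  # every query vector now has unit norm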
@@ -1581,6 +1593,26 @@ class Attention(Module):
 
  out = rearrange(out, 'b h n d -> b n (h d)')
 
+ # hybrid module
+
+ if exists(self.hybrid_module):
+
+     # hybrid input
+
+     hybrid_forward_kwargs = dict()
+
+     if not self.causal and exists(self.hybrid_mask_kwarg):
+         hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
+
+     # hybrid forward
+
+     hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+
+     # handle hybrid out
+
+     (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+     out = 0.5 * (out + hybrid_out)
+
  # alphafold2 styled gating of the values
 
  if exists(self.to_v_gate):
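
`tree_flatten` is what lets the hybrid branch accept modules that return either a single tensor or a tuple, e.g. a GRU returns `(output, h_n)`; only the first flattened leaf is blended into `out`. An illustrative snippet, again assuming a GRU as the hybrid module:

    import torch
    from torch.nn import GRU
    from torch.utils._pytree import tree_flatten

    gru = GRU(512, 512, batch_first = True)
    x = torch.randn(2, 16, 512)

    outputs = gru(x)                           # a (output, h_n) tuple
    (first, *rest), _ = tree_flatten(outputs)  # first leaf is the per-token sequence output
    print(first.shape)                         # torch.Size([2, 16, 512])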
@@ -2003,7 +2035,7 @@ class AttentionLayers(Module):
 
  # determine whether can cache kv
 
- self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+ self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
 
  def forward(
      self,
x_transformers-1.44.1.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.43.5
+ Version: 1.44.1
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.44.1.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=z3RH6jvjcxaAVfZoCS0HWrE0Gy55-eXOKtzRt7rRRIw,100811
+ x_transformers/x_transformers.py,sha256=yjtB4kV4N9mzHdliIM9MjyA6SoMtvpzc2Z4iU6R9_Uc,101859
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-1.43.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.43.5.dist-info/METADATA,sha256=crd2xA3NbodKbOz9xY4D1j3XDbTmaY9vwkZZJOGoEw4,738
- x_transformers-1.43.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- x_transformers-1.43.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.43.5.dist-info/RECORD,,
+ x_transformers-1.44.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.44.1.dist-info/METADATA,sha256=Zw_Rscb4vNZxlKosWSHSQy4EsICF45U58K0hipxydpQ,738
+ x_transformers-1.44.1.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
+ x_transformers-1.44.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.44.1.dist-info/RECORD,,
x_transformers-1.44.1.dist-info/WHEEL

@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (75.6.0)
+ Generator: setuptools (75.7.0)
  Root-Is-Purelib: true
  Tag: py3-none-any