x-transformers 1.43.4__py3-none-any.whl → 1.44.0__py3-none-any.whl

x_transformers/x_transformers.py

@@ -2,6 +2,7 @@ from __future__ import annotations
  from typing import Callable

  import math
+ from copy import deepcopy
  from random import random, randrange
  from packaging import version

@@ -38,6 +39,7 @@ class LayerIntermediates:
  attn_z_loss: Tensor | None = None
  mems: Tensor | None = None
  memory_tokens: Tensor | None = None
+ logit_entropies: Tensor | None = None

  LinearNoBias = partial(nn.Linear, bias = False)

@@ -136,6 +138,15 @@ def or_reduce(masks):
  head = head | rest
  return head

+ # entropy
+
+ def calc_entropy(
+     t: Tensor,
+     is_prob = False
+ ):
+     prob = t.softmax(dim = -1) if not is_prob else t
+     return -(prob * log(prob)).sum(dim = -1)
+
  # auxiliary loss helpers

  def calc_z_loss(
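
For reference, the new calc_entropy helper reduces the last (vocab) dimension to a per-position Shannon entropy. A minimal standalone sketch of the same computation, assuming only plain torch (torch.log stands in for the module's own log helper):

    import torch

    def calc_entropy(t, is_prob = False):
        # softmax raw logits unless probabilities were passed in
        prob = t.softmax(dim = -1) if not is_prob else t
        # Shannon entropy over the last dimension
        return -(prob * torch.log(prob)).sum(dim = -1)

    logits = torch.randn(2, 1024, 20000)   # (batch, seq, vocab)
    entropies = calc_entropy(logits)       # (batch, seq)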
@@ -1126,6 +1137,7 @@ class Attention(Module):
  sigmoid = False,
  selective = False,
  custom_attn_fn: Callable | None = None,
+ hybrid_module: Module | None = None,
  one_kv_head = False,
  kv_heads = None,
  shared_kv = False,
@@ -1325,6 +1337,10 @@ class Attention(Module):

  self.attn_on_attn = on_attn

+ # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676
+
+ self.hybrid_module = deepcopy(hybrid_module) if exists(hybrid_module) else None
+
  # output dimension by default same as input, but can be overridden

  dim_out = default(dim_out, dim)
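
The new hybrid_module argument lets a recurrent or convolutional branch run alongside attention, in the vein of Hymba (https://www.arxiv.org/abs/2411.13676). Because the constructor deepcopies the module, each attention layer ends up with its own copy of whatever is passed in. A sketch of one way to wire it up, assuming the library's usual attn_-prefixed kwarg routing; the attn_hybrid_module name and the GRU/dimension choices here are illustrative, not taken from the diff:

    import torch
    from torch import nn
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_dim_head = 64,
            # the hybrid branch's output must match the pre-projection
            # attention width, heads * dim_head = 512 in this example
            attn_hybrid_module = nn.GRU(512, 64 * 8, batch_first = True)
        )
    )

    x = torch.randint(0, 20000, (1, 1024))
    logits = model(x)   # (1, 1024, 20000)

A GRU fits the call convention because its forward returns a (output, hidden) pair, which matches how the attention layer unpacks the hybrid branch below.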
@@ -1397,6 +1413,16 @@ class Attention(Module):
  value_residual_mix = self.to_value_residual_mix(q_input)
  v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

+ # qk normalization
+
+ if self.qk_norm:
+     qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+     q, k = map(qk_l2norm, (q, k))
+     scale = self.qk_norm_scale
+
+     q = q * self.qk_norm_q_scale
+     k = k * self.qk_norm_k_scale
+
  # take care of caching

  if exists(cache):
@@ -1417,14 +1443,6 @@ class Attention(Module):
  mem_len = mem.shape[-2] if exists(mem) else 0
  cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

- if self.qk_norm:
-     qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
-     q, k = map(qk_l2norm, (q, k))
-     scale = self.qk_norm_scale
-
-     q = q * self.qk_norm_q_scale
-     k = k * self.qk_norm_k_scale
-
  if exists(rotary_pos_emb):
  freqs, xpos_scale = rotary_pos_emb
  q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
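
The qk-normalization block itself is unchanged; it is only moved ahead of the cache handling. Conceptually it L2-normalizes queries and keys (optionally in groups along the feature dimension) and then relies on the module's qk_norm_scale, plus the optional per-dimension q/k scales, instead of the plain 1/sqrt(d). A rough standalone sketch of the normalization step; the real l2norm helper lives in the library and may differ in details:

    import torch
    import torch.nn.functional as F

    def l2norm(t, groups = 1):
        # normalize each group of the last dimension to unit length
        *lead, d = t.shape
        t = t.reshape(*lead, groups, d // groups)
        t = F.normalize(t, p = 2, dim = -1)
        return t.reshape(*lead, d)

    q = torch.randn(1, 8, 1024, 64)   # (batch, heads, seq, dim_head)
    k = torch.randn(1, 8, 1024, 64)
    q, k = map(l2norm, (q, k))        # unit-norm q/k before the learned scales are applied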
@@ -1571,6 +1589,12 @@ class Attention(Module):

  out = rearrange(out, 'b h n d -> b n (h d)')

+ # hybrid module
+
+ if exists(self.hybrid_module):
+     hybrid_out, _ = self.hybrid_module(x)
+     out = 0.5 * (out + hybrid_out)
+
  # alphafold2 styled gating of the values

  if exists(self.to_v_gate):
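
This hunk fixes the interface a hybrid module has to satisfy: it is called on the attention block's input x of shape (batch, seq, dim), must return a pair whose first element has shape (batch, seq, heads * dim_head), and that output is simply averaged with the attention output before gating and the output projection. Any module meeting that contract works; below is a toy causal depthwise-conv branch, purely as an illustration of the shape contract (nothing like it ships with the library):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class ConvBranch(nn.Module):
        def __init__(self, dim, dim_inner, kernel = 3):
            super().__init__()
            self.proj = nn.Linear(dim, dim_inner)
            self.conv = nn.Conv1d(dim_inner, dim_inner, kernel, groups = dim_inner)
            self.kernel = kernel

        def forward(self, x):
            x = self.proj(x).transpose(1, 2)
            x = F.pad(x, (self.kernel - 1, 0))   # left-pad so the conv stays causal
            x = self.conv(x).transpose(1, 2)
            return x, None                       # second value is ignored by Attention

    branch = ConvBranch(512, 8 * 64)
    out, _ = branch(torch.randn(1, 1024, 512))   # (1, 1024, 512)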
@@ -1993,7 +2017,7 @@ class AttentionLayers(Module):

  # determine whether can cache kv

- self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention) ])
+ self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])

  def forward(
  self,
@@ -2592,6 +2616,7 @@ class TransformerWrapper(Module):
  return_embeddings = False,
  return_logits_and_embeddings = False,
  return_intermediates = False,
+ return_logit_entropies = False,
  mask = None,
  return_mems = False,
  return_attn = False,
@@ -2809,6 +2834,12 @@ class TransformerWrapper(Module):
  else:
  out = logits

+ # logit entropies
+
+ if return_logit_entropies:
+     intermediates.logit_entropies = calc_entropy(logits)
+     return_intermediates = True
+
  # aux loss

  if return_attn_z_loss:
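
Since passing return_logit_entropies also flips return_intermediates on, the wrapper returns a (logits, intermediates) pair with the per-position entropies attached. A usage sketch, reusing the hypothetical model built in the example above:

    import torch

    x = torch.randint(0, 20000, (1, 1024))

    logits, intermediates = model(x, return_logit_entropies = True)

    entropies = intermediates.logit_entropies   # (1, 1024), one entropy per position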
x_transformers-1.43.4.dist-info/METADATA → x_transformers-1.44.0.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.43.4
+ Version: 1.44.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.43.4.dist-info/RECORD → x_transformers-1.44.0.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=gn0vRtwbjBA67T-Z8tkU-k3Xte0PaMTxZlmzdK8UsFw,100392
+ x_transformers/x_transformers.py,sha256=BI3RU3XFvwSNDZgoQBrFBSJ4SavJr38rOCCVgHZBTx0,101241
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-1.43.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.43.4.dist-info/METADATA,sha256=Nlj9DcMqnMxJH-xR4Dwd4aU1U-UQIUshpQaMDcggVes,738
- x_transformers-1.43.4.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- x_transformers-1.43.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.43.4.dist-info/RECORD,,
+ x_transformers-1.44.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.44.0.dist-info/METADATA,sha256=MNVwW_pDeKEIHRVEA1XOUNfzFmL706X7Npoh7xc3wIk,738
+ x_transformers-1.44.0.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ x_transformers-1.44.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.44.0.dist-info/RECORD,,