x-transformers 1.32.7__py3-none-any.whl → 1.32.8__py3-none-any.whl

This diff covers two publicly available versions of the package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
x_transformers/continuous.py

@@ -9,6 +9,7 @@ from x_transformers.x_transformers import (
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
     LayerNorm,
+    masked_mean,
     always,
     pad_at_dim
 )
@@ -39,7 +40,8 @@ class ContinuousTransformerWrapper(nn.Module):
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
+        scaled_sinu_pos_emb = False,
+        average_pool_embed = False
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -72,6 +74,10 @@ class ContinuousTransformerWrapper(nn.Module):

         self.attn_layers = attn_layers

+        # average pool
+
+        self.average_pool_embed = average_pool_embed
+
         # project in and out

         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
@@ -92,7 +98,7 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_mask = None,
         **kwargs
     ):
-        batch, seq, device = *x.shape[:2], x.device
+        batch, seq, orig_mask, device = *x.shape[:2], mask, x.device

         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
@@ -136,6 +142,11 @@ class ContinuousTransformerWrapper(nn.Module):
             m, x = unpack(x, mem_ps, 'b * d')
             intermediates.memory_tokens = m

+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask)
+
+        # maybe linear project out
+
         out = self.project_out(x) if not return_embeddings else x

         if return_intermediates:
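Taken together, the continuous.py hunks add an average_pool_embed flag that mean-pools the final hidden states over the unpadded positions (using the original mask) before project_out. A minimal usage sketch, assuming the constructor arguments documented in the project README; all dimensions are illustrative:

import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# illustrative sketch - only average_pool_embed is new in this release
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    average_pool_embed = True
)

x = torch.randn(1, 1024, 32)
mask = torch.ones(1, 1024).bool()   # kept as orig_mask and passed to masked_mean

out = model(x, mask = mask)         # (1, 100) - pooled over the sequence, then projected out

Without the flag the output keeps its sequence dimension, (1, 1024, 100).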
x_transformers/x_transformers.py

@@ -99,6 +99,17 @@ def l2norm(t, groups = 1):
 def softclamp(t, value):
     return (t / value).tanh() * value

+def masked_mean(t, mask = None, dim = 1):
+    if not exists(mask):
+        return t.mean(dim = dim)
+
+    dims_append = (1,) * (t.ndim - mask.ndim)
+    mask = mask.reshape(*mask.shape, *dims_append)
+
+    num = (t * mask).sum(dim = dim)
+    den = mask.sum(dim = dim).clamp(min = 1.)
+    return num / den
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
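The new masked_mean helper averages over a dimension while ignoring masked-out positions, broadcasting the mask over trailing dimensions and clamping the denominator so fully masked rows do not divide by zero. A standalone sketch for illustration (the function body is copied from the hunk above so the snippet runs on its own; the exists stub and toy tensors are not part of the release):

import torch

def exists(val):
    return val is not None

# copy of the masked_mean added above, so this snippet is self-contained
def masked_mean(t, mask = None, dim = 1):
    if not exists(mask):
        return t.mean(dim = dim)

    dims_append = (1,) * (t.ndim - mask.ndim)
    mask = mask.reshape(*mask.shape, *dims_append)   # (b, n) -> (b, n, 1) for a (b, n, d) input

    num = (t * mask).sum(dim = dim)
    den = mask.sum(dim = dim).clamp(min = 1.)        # guard against fully masked rows
    return num / den

t = torch.randn(2, 5, 8)                             # (batch, seq, dim)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]]).bool()        # first sequence has two padded positions

pooled = masked_mean(t, mask = mask)                 # (2, 8)
assert torch.allclose(pooled[0], t[0, :3].mean(dim = 0), atol = 1e-6)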
@@ -1666,6 +1677,8 @@ class AttentionLayers(Module):

         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

+        first_skip = None
+
         # go through the attention and feedforward layers

         for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1674,6 +1687,9 @@ class AttentionLayers(Module):
             if self.training and layer_dropout > 0. and random() < layer_dropout:
                 continue

+            if ind == 1:
+                first_skip = x.clone()
+
             if layer_type == 'a':
                 if return_hiddens:
                     hiddens.append(x)
@@ -1909,6 +1925,7 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
+        average_pool_embed = False
     ):
         super().__init__()

@@ -1954,6 +1971,10 @@ class TransformerWrapper(Module):

         assert num_output_heads > 0

+        # whether to average pool the embed (`global average pool`)
+
+        self.average_pool_embed = average_pool_embed
+
         # output head, usually to logits of num_tokens

         logits_dim = default(logits_dim, num_tokens)
@@ -2015,7 +2036,7 @@ class TransformerWrapper(Module):
         cache: LayerIntermediates | None = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask

         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
         return_embeddings = return_embeddings | (not exists(self.to_logits))
@@ -2118,6 +2139,11 @@ class TransformerWrapper(Module):

         x = x[:, :n]

+        # global average pool
+
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask, dim = 1)
+
         # projecting to logits

         if not return_embeddings:
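TransformerWrapper gets the same treatment: when average_pool_embed is set, the token representations are masked-mean pooled over the sequence before the logits projection, so the wrapper emits one prediction per sequence instead of one per token. A hedged usage sketch, assuming the standard README-style constructor arguments (all values illustrative):

import torch
from x_transformers import TransformerWrapper, Encoder

# illustrative sketch - only average_pool_embed is new in this release
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8),
    average_pool_embed = True
)

x = torch.randint(0, 20000, (2, 1024))
mask = torch.ones(2, 1024).bool()   # captured as orig_mask and consumed by masked_mean

logits = model(x, mask = mask)                             # (2, 20000) rather than (2, 1024, 20000)
embeds = model(x, mask = mask, return_embeddings = True)   # (2, 512) pooled embeddings

Since the pooling happens before the return_embeddings branch, the pooled representation is also what comes back when asking for embeddings, which is the usual setup for sequence-level classification or retrieval heads.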
x_transformers-1.32.8.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.7
+Version: 1.32.8
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.32.8.dist-info/RECORD

@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
-x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
+x_transformers/x_transformers.py,sha256=cq364zjUVvGEeFxdu703yI2tp1VhpxTIpLTgMshHpzI,77392
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.7.dist-info/METADATA,sha256=25J9CJ3OxsR_SZkvubPhyjSN-NmvU_yVVQHNMFzoKVg,661
-x_transformers-1.32.7.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.7.dist-info/RECORD,,
+x_transformers-1.32.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.8.dist-info/METADATA,sha256=tEHQVjqqXKQ2eSd-j5yrmfubkjuQZga_sT_5XgHanQo,661
+x_transformers-1.32.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.8.dist-info/RECORD,,