x-transformers 1.32.7__py3-none-any.whl → 1.32.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
x_transformers/continuous.py

@@ -9,6 +9,7 @@ from x_transformers.x_transformers import (
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
     LayerNorm,
+    masked_mean,
     always,
     pad_at_dim
 )
@@ -39,7 +40,8 @@ class ContinuousTransformerWrapper(nn.Module):
         post_emb_norm = False,
         emb_dropout = 0.,
         use_abs_pos_emb = True,
-        scaled_sinu_pos_emb = False
+        scaled_sinu_pos_emb = False,
+        average_pool_embed = False
     ):
         super().__init__()
         dim = attn_layers.dim
@@ -72,6 +74,10 @@ class ContinuousTransformerWrapper(nn.Module):
 
         self.attn_layers = attn_layers
 
+        # average pool
+
+        self.average_pool_embed = average_pool_embed
+
         # project in and out
 
         self.project_in = nn.Linear(dim_in, dim, bias = False) if exists(dim_in) else nn.Identity()
@@ -92,7 +98,7 @@ class ContinuousTransformerWrapper(nn.Module):
         prepend_mask = None,
         **kwargs
     ):
-        batch, seq, device = *x.shape[:2], x.device
+        batch, seq, orig_mask, device = *x.shape[:2], mask, x.device
 
         x = self.project_in(x)
         x = x + self.pos_emb(x, pos = pos)
@@ -136,6 +142,11 @@ class ContinuousTransformerWrapper(nn.Module):
             m, x = unpack(x, mem_ps, 'b * d')
             intermediates.memory_tokens = m
 
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask)
+
+        # maybe linear project out
+
         out = self.project_out(x) if not return_embeddings else x
 
         if return_intermediates:
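
Taken together, the continuous.py hunks add an opt-in pooling step: when the new average_pool_embed flag is set, the wrapper records the incoming padding mask as orig_mask and, after the attention layers (and after any memory tokens have been split back off), replaces the per-position outputs with a single masked mean over the sequence before project_out. A minimal usage sketch under that assumption; the Encoder hyperparameters and tensor shapes below are illustrative, not taken from the diff:

    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder

    model = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 64,
        max_seq_len = 256,
        average_pool_embed = True,   # flag introduced in this diff
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    x = torch.randn(2, 100, 32)        # (batch, seq, dim_in)
    mask = torch.ones(2, 100).bool()   # True where positions are real, False where padded

    out = model(x, mask = mask)        # (2, 64): one pooled vector per sequence instead of (2, 100, 64)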
x_transformers/x_transformers.py

@@ -99,6 +99,17 @@ def l2norm(t, groups = 1):
 def softclamp(t, value):
     return (t / value).tanh() * value
 
+def masked_mean(t, mask = None, dim = 1):
+    if not exists(mask):
+        return t.mean(dim = dim)
+
+    dims_append = (1,) * (t.ndim - mask.ndim)
+    mask = mask.reshape(*mask.shape, *dims_append)
+
+    num = (t * mask).sum(dim = dim)
+    den = mask.sum(dim = dim).clamp(min = 1.)
+    return num / den
+
 def pad_at_dim(t, pad: Tuple[int, int], dim = -1, value = 0.):
     if pad == (0, 0):
         return t
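
The masked_mean helper added here computes a mean over dim while ignoring masked-out positions: the mask is right-padded with singleton dimensions so it broadcasts against t, the masked sum is divided by the count of kept positions, and the denominator is clamped to at least 1 so fully masked rows yield zeros rather than NaNs. A small standalone check of that behaviour, re-declaring an equivalent helper locally (with an explicit float cast) rather than importing the module-level utility:

    import torch

    def masked_mean(t, mask = None, dim = 1):
        # mirrors the helper introduced in x_transformers/x_transformers.py above
        if mask is None:
            return t.mean(dim = dim)

        mask = mask.reshape(*mask.shape, *((1,) * (t.ndim - mask.ndim))).float()
        num = (t * mask).sum(dim = dim)
        den = mask.sum(dim = dim).clamp(min = 1.)
        return num / den

    t = torch.tensor([[[1.], [3.], [100.]]])      # (batch = 1, seq = 3, dim = 1)
    mask = torch.tensor([[True, True, False]])    # last position is padding

    print(masked_mean(t, mask))   # tensor([[2.]]) -- mean over the two unmasked positions
    print(t.mean(dim = 1))        # tensor([[34.6667]]) -- naive mean is skewed by the padded value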
x_transformers/x_transformers.py

@@ -1909,6 +1920,7 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
+        average_pool_embed = False
     ):
         super().__init__()
 
@@ -1954,6 +1966,10 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        # whether to average pool the embed (`global average pool`)
+
+        self.average_pool_embed = average_pool_embed
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
@@ -2015,7 +2031,7 @@ class TransformerWrapper(Module):
         cache: LayerIntermediates | None = None,
         **kwargs
     ):
-        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+        b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
         return_embeddings = return_embeddings | (not exists(self.to_logits))
@@ -2118,6 +2134,11 @@ class TransformerWrapper(Module):
 
         x = x[:, :n]
 
+        # global average pool
+
+        if self.average_pool_embed:
+            x = masked_mean(x, mask = orig_mask, dim = 1)
+
         # projecting to logits
 
         if not return_embeddings:
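
The TransformerWrapper hunks mirror the continuous wrapper: the original mask is captured as orig_mask at the top of forward, and when average_pool_embed is set the token embeddings are mean-pooled over the sequence right before the logits projection, which suits classification-style heads. A minimal sketch under that assumption; num_classes, the Encoder settings and the shapes are illustrative only:

    import torch
    from x_transformers import TransformerWrapper, Encoder

    num_classes = 10   # hypothetical downstream label count

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 512,
        logits_dim = num_classes,      # output head maps the pooled embedding to class logits
        average_pool_embed = True,     # flag introduced in this diff
        attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
    )

    ids = torch.randint(0, 20000, (2, 512))
    mask = torch.ones(2, 512).bool()   # padding mask; the pooled mean only counts unmasked tokens

    logits = model(ids, mask = mask)   # (2, num_classes): one pooled prediction per sequence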
x_transformers-1.32.7.dist-info/METADATA → x_transformers-1.32.9.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.7
+Version: 1.32.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.32.7.dist-info/RECORD → x_transformers-1.32.9.dist-info/RECORD

@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
 x_transformers/attend.py,sha256=MI-m91wumBFqFqr_KK9MLgsLk_vPeaVbFMyDr_mWdmY,11349
 x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
-x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
+x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=5DHbYgx0RPg9QHvfBs2qHWrtn4Jji-q0d1MRBbcRPR8,76696
+x_transformers/x_transformers.py,sha256=8558TPHcDxWUvJYz01EdeyZl0lkHB14bzlsEMwSMPyw,77300
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.7.dist-info/METADATA,sha256=25J9CJ3OxsR_SZkvubPhyjSN-NmvU_yVVQHNMFzoKVg,661
-x_transformers-1.32.7.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.7.dist-info/RECORD,,
+x_transformers-1.32.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.9.dist-info/METADATA,sha256=-GidCdPhcKpZ49ElbeuJUPko5LZZP_vyEodaN_P3g48,661
+x_transformers-1.32.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.9.dist-info/RECORD,,