x-transformers 1.32.9 (py3-none-any.whl) → 1.32.11 (py3-none-any.whl)

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -45,15 +45,18 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def cast_tuple(val, depth):
+def first(it):
+    return it[0]
+
+def is_empty(x):
+    return len(x) == 0
+
+def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth
 
 def divisible_by(num, den):
     return (num % den) == 0
 
-def is_empty(x):
-    return len(x) == 0
-
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
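
For reference, a small standalone sketch of the helpers touched above; the main behavioral change is that cast_tuple now defaults depth to 1:

def first(it):
    return it[0]

def is_empty(x):
    return len(x) == 0

def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else (val,) * depth

assert cast_tuple(3) == (3,)        # depth now defaults to 1
assert cast_tuple(3, 2) == (3, 3)
assert first((3, 5)) == 3
assert is_empty(())
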
@@ -1920,7 +1923,9 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
-        average_pool_embed = False
+        average_pool_embed = False,
+        use_cls_token = False,
+        squeeze_out_last_dim = False
     ):
         super().__init__()
 
@@ -1966,6 +1971,16 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        assert at_most_one_of(average_pool_embed, use_cls_token)
+
+        # classic cls token from the bert days
+
+        self.cls_token = None
+
+        if use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(dim))
+            nn.init.normal_(self.cls_token, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
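
The assertion above relies on at_most_one_of, whose definition is not part of this diff; below is a hypothetical sketch consistent with its name and usage here (the library's actual helper may differ):

def at_most_one_of(*bools):
    # true when no more than one of the given flags is set
    return sum(map(int, bools)) <= 1

assert at_most_one_of(False, False)
assert at_most_one_of(True, False)
assert not at_most_one_of(True, True)
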
@@ -1995,6 +2010,10 @@ class TransformerWrapper(Module):
 
         self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
 
+        # squeeze out last dimension if possible
+
+        self.squeeze_out_last_dim = squeeze_out_last_dim
+
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
@@ -2092,7 +2111,19 @@ class TransformerWrapper(Module):
 
         x = self.project_emb(x)
 
+        # maybe cls token
+
+        if exists(self.cls_token):
+            cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+            x, cls_packed_shape = pack([cls_tokens, x], 'b * d')
+
+            if exists(mask):
+                mask = F.pad(mask, (1, 0), value = True)
+
+        # maybe memory / register tokens
+
         if has_memory_tokens:
+            mem_seq = x.shape[-2]
             mem_every = self.memory_tokens_interspersed_every
 
             if exists(mem_every):
@@ -2132,13 +2163,16 @@ class TransformerWrapper(Module):
             if exists(mem_every):
                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
-            x = x[:, :n]
+            x = x[:, :mem_seq]
 
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+        if exists(self.cls_token):
+            x, _ = unpack(x, cls_packed_shape, 'b * d')
+
         # projecting to logits
 
         if not return_embeddings:
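
A minimal standalone sketch of the einops pack/unpack pattern used above to prepend the cls token and later pull it back out; the shapes are illustrative only:

import torch
from einops import pack, repeat, unpack

b, n, d = 2, 5, 8                                      # illustrative batch, sequence, model dims
cls_token = torch.zeros(d)                             # stands in for the learned parameter
x = torch.randn(b, n, d)

cls_tokens = repeat(cls_token, 'd -> b d', b = b)      # one copy of the token per batch element
x, cls_packed_shape = pack([cls_tokens, x], 'b * d')   # (b, n + 1, d), cls token at position 0

# ... attention layers would run on the packed sequence here ...

cls_out, _ = unpack(x, cls_packed_shape, 'b * d')      # recover just the cls position: (b, d)
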
@@ -2147,6 +2181,14 @@ class TransformerWrapper(Module):
             else:
                 logits = self.to_logits(x)
 
+        # maybe squeeze out last dimension of logits
+
+        if self.squeeze_out_last_dim:
+            logits = tuple(rearrange(t, '... 1 -> ...') for t in cast_tuple(logits))
+
+            if not self.has_multiple_heads:
+                logits = first(logits)
+
         # different returns
 
         if return_logits_and_embeddings:
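
Putting the pieces together, a hedged usage sketch of the two new TransformerWrapper flags; num_tokens, max_seq_len, the Encoder settings, and the logits_dim keyword are illustrative assumptions rather than values taken from this diff:

import torch
from x_transformers import TransformerWrapper, Encoder

# cls-token pooling: the returned embedding is the learned token prepended above
model = TransformerWrapper(
    num_tokens = 256,                 # assumed vocabulary size for the sketch
    max_seq_len = 1024,
    use_cls_token = True,             # new flag in this release
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

tokens = torch.randint(0, 256, (2, 1024))
pooled = model(tokens, return_embeddings = True)   # (2, 512): the cls token embedding

# squeeze_out_last_dim drops a trailing singleton logit dimension, e.g. for a scalar head
# (logits_dim = 1 is assumed to be an accepted keyword here)
regressor = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    logits_dim = 1,
    squeeze_out_last_dim = True,      # new flag in this release
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

scores = regressor(tokens)            # (2, 1024) instead of (2, 1024, 1)
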
x_transformers-1.32.11.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.9
+Version: 1.32.11
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.32.11.dist-info/RECORD

@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=8558TPHcDxWUvJYz01EdeyZl0lkHB14bzlsEMwSMPyw,77300
+x_transformers/x_transformers.py,sha256=Z04p-xySEkTgHSaY_060M0RlF6LnkK8ko5yTLunIYf8,78520
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.9.dist-info/METADATA,sha256=-GidCdPhcKpZ49ElbeuJUPko5LZZP_vyEodaN_P3g48,661
-x_transformers-1.32.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.9.dist-info/RECORD,,
+x_transformers-1.32.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.11.dist-info/METADATA,sha256=dVUc_T7ALnVvaSKHSiPKOR4Y4zAo19l0QedFQPOEbN8,662
+x_transformers-1.32.11.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.11.dist-info/RECORD,,