x-transformers 1.32.10.tar.gz → 1.32.12.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.32.10/x_transformers.egg-info → x_transformers-1.32.12}/PKG-INFO +1 -1
  2. {x_transformers-1.32.10 → x_transformers-1.32.12}/setup.py +1 -1
  3. {x_transformers-1.32.10 → x_transformers-1.32.12}/tests/test_x_transformers.py +21 -2
  4. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/x_transformers.py +20 -4
  5. {x_transformers-1.32.10 → x_transformers-1.32.12/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.32.10 → x_transformers-1.32.12}/LICENSE +0 -0
  7. {x_transformers-1.32.10 → x_transformers-1.32.12}/README.md +0 -0
  8. {x_transformers-1.32.10 → x_transformers-1.32.12}/setup.cfg +0 -0
  9. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.32.10/x_transformers.egg-info → x_transformers-1.32.12}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.10
+Version: 1.32.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.32.10 → x_transformers-1.32.12}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.32.10',
+  version = '1.32.12',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
{x_transformers-1.32.10 → x_transformers-1.32.12}/tests/test_x_transformers.py
@@ -159,7 +159,6 @@ def test_multiple_input_embeds():
     assert embed.shape == (2, 1024, 128)
 
 def test_average_pool_embed():
-
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
@@ -180,7 +179,6 @@ def test_average_pool_embed():
     assert logits.shape == (2, 20000)
 
 def test_cls_token():
-
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
@@ -199,3 +197,24 @@ def test_cls_token():
     logits = model(x, mask = mask)
 
     assert logits.shape == (2, 20000)
+
+def test_squeeze_logit_dim_one():
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        logits_dim = 1,
+        average_pool_embed = True,
+        squeeze_out_last_dim = True,
+        attn_layers = Encoder(
+            dim = 128,
+            depth = 6,
+            heads = 8
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 1024))
+    mask = torch.randint(0, 2, (2, 1024)).bool()
+
+    logits = model(x, mask = mask)
+
+    assert logits.shape == (2,)
{x_transformers-1.32.10 → x_transformers-1.32.12}/x_transformers/x_transformers.py
@@ -45,15 +45,18 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def cast_tuple(val, depth):
+def first(it):
+    return it[0]
+
+def is_empty(x):
+    return len(x) == 0
+
+def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth
 
 def divisible_by(num, den):
     return (num % den) == 0
 
-def is_empty(x):
-    return len(x) == 0
-
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
@@ -1922,6 +1925,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
+        squeeze_out_last_dim = False
     ):
         super().__init__()
 
@@ -2006,6 +2010,10 @@ class TransformerWrapper(Module):
 
         self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
 
+        # squeeze out last dimension if possible
+
+        self.squeeze_out_last_dim = squeeze_out_last_dim
+
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
@@ -2173,6 +2181,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)
 
+        # maybe squeeze out last dimension of logits
+
+        if self.squeeze_out_last_dim:
+            logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))
+
+            if not self.has_multiple_heads:
+                logits = first(logits)
+
         # different returns
 
         if return_logits_and_embeddings:
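
Taken together, the x_transformers.py changes add a squeeze_out_last_dim option to TransformerWrapper: when enabled, any logits tensor whose last dimension is 1 has that dimension dropped via rearrange(t, '... 1 -> ...'), and for a single-head model the enclosing tuple is unwrapped with the new first helper. A minimal usage sketch, mirroring the test_squeeze_logit_dim_one test added above (shapes follow its assertion):

import torch
from x_transformers import TransformerWrapper, Encoder

# one scalar logit per sequence: pool tokens, project to a single logit, squeeze it out
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    logits_dim = 1,               # project each pooled embedding to a single logit
    average_pool_embed = True,    # mean-pool token embeddings into one vector per sequence
    squeeze_out_last_dim = True,  # drop the trailing dimension of size 1
    attn_layers = Encoder(
        dim = 128,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (2, 1024))

scores = model(x)  # shape (2,) rather than (2, 1)
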
{x_transformers-1.32.10 → x_transformers-1.32.12/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.10
+Version: 1.32.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang