x-transformers 1.40.9__tar.gz → 1.40.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {x_transformers-1.40.9/x_transformers.egg-info → x_transformers-1.40.11}/PKG-INFO +1 -1
  2. {x_transformers-1.40.9 → x_transformers-1.40.11}/setup.py +1 -1
  3. {x_transformers-1.40.9 → x_transformers-1.40.11}/tests/test_x_transformers.py +9 -2
  4. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/x_transformers.py +7 -4
  5. {x_transformers-1.40.9 → x_transformers-1.40.11/x_transformers.egg-info}/PKG-INFO +1 -1
  6. {x_transformers-1.40.9 → x_transformers-1.40.11}/LICENSE +0 -0
  7. {x_transformers-1.40.9 → x_transformers-1.40.11}/README.md +0 -0
  8. {x_transformers-1.40.9 → x_transformers-1.40.11}/setup.cfg +0 -0
  9. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/__init__.py +0 -0
  10. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/attend.py +0 -0
  11. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/autoregressive_wrapper.py +0 -0
  12. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/continuous.py +0 -0
  13. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/dpo.py +0 -0
  14. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/multi_input.py +0 -0
  15. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/nonautoregressive_wrapper.py +0 -0
  16. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  17. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/xval.py +0 -0
  18. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers.egg-info/SOURCES.txt +0 -0
  19. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers.egg-info/dependency_links.txt +0 -0
  20. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers.egg-info/requires.txt +0 -0
  21. {x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.40.9/x_transformers.egg-info → x_transformers-1.40.11}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.9
+ Version: 1.40.11
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x_transformers-1.40.9 → x_transformers-1.40.11}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.40.9',
+   version = '1.40.11',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x_transformers-1.40.9 → x_transformers-1.40.11}/tests/test_x_transformers.py
@@ -179,12 +179,14 @@ def test_average_pool_embed():

      assert logits.shape == (2, 20000)

- def test_cls_token():
+ @pytest.mark.parametrize('num_cls_tokens', (1, 2))
+ def test_cls_token(num_cls_tokens):
      model = TransformerWrapper(
          num_tokens = 20000,
          max_seq_len = 1024,
          num_memory_tokens = 2,
          use_cls_token = True,
+         num_cls_tokens=num_cls_tokens,
          attn_layers = Encoder(
              dim = 128,
              depth = 6,
@@ -197,7 +199,12 @@ def test_cls_token():

      logits = model(x, mask = mask)

-     assert logits.shape == (2, 20000)
+     if num_cls_tokens == 1:
+         expected_shape = (2, 20000)
+     else:
+         expected_shape = (2, num_cls_tokens, 20000)
+
+     assert logits.shape == expected_shape

  def test_squeeze_logit_dim_one():
      model = TransformerWrapper(
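The parametrized test above exercises the new num_cls_tokens option added in this release. A minimal usage sketch of what it checks, using the TransformerWrapper / Encoder API from the test itself (the Encoder settings below are illustrative; the expected shapes are taken from the updated assertions):

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_cls_token = True,
    num_cls_tokens = 2,       # new keyword; defaults to 1, preserving the old behavior
    attn_layers = Encoder(
        dim = 128,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (2, 1024))
mask = torch.ones(2, 1024, dtype = torch.bool)

logits = model(x, mask = mask)
# per the test: (2, 20000) when num_cls_tokens = 1, (2, num_cls_tokens, 20000) otherwise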
{x_transformers-1.40.9 → x_transformers-1.40.11}/x_transformers/x_transformers.py
@@ -1019,7 +1019,7 @@ class Attention(Module):
          self.qk_norm_q_scale = self.qk_norm_k_scale = 1
          if qk_norm and qk_norm_dim_scale:
              self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
-             self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
+             self.qk_norm_k_scale = nn.Parameter(torch.ones(kv_heads, 1, dim_head))

          assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
          assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
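The hunk above sizes the learned key-norm scale by kv_heads instead of heads, which only differs when grouped-query attention uses fewer key/value heads than query heads. A hedged sketch of a configuration on that code path, assuming the attn_-prefixed keyword names documented in the project README (dim, depth, and sequence length are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_kv_heads = 2,             # grouped-query attention: 2 key/value heads shared by 8 query heads
        attn_qk_norm = True,
        attn_qk_norm_dim_scale = True  # learned per-head scales; the k scale is now shaped (kv_heads, 1, dim_head)
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)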
@@ -2104,6 +2104,7 @@ class TransformerWrapper(Module):
          attn_z_loss_weight = 1e-4,
          average_pool_embed = False,
          use_cls_token = False,
+         num_cls_tokens = 1,
          squeeze_out_last_dim = False,
          token_emb: TokenEmbedding | None = None,
          mixture_of_softmax = False,
@@ -2116,6 +2117,7 @@ class TransformerWrapper(Module):
          emb_dim = default(emb_dim, dim)
          self.emb_dim = emb_dim
          self.num_tokens = num_tokens
+         self.num_cls_tokens = num_cls_tokens

          self.max_seq_len = max_seq_len
          self.max_mem_len = max_mem_len
@@ -2172,7 +2174,7 @@ class TransformerWrapper(Module):
          self.cls_token = None

          if use_cls_token:
-             self.cls_token = nn.Parameter(torch.zeros(dim))
+             self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
              nn.init.normal_(self.cls_token, std = 0.02)

          # whether to average pool the embed (`global average pool`)
@@ -2329,11 +2331,11 @@ class TransformerWrapper(Module):
          # maybe cls token

          if exists(self.cls_token):
-             cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+             cls_tokens = repeat(self.cls_token, '... -> b ...', b = b)
              x, cls_packed_shape = pack([cls_tokens, x], 'b * d')

              if exists(mask):
-                 mask = F.pad(mask, (1, 0), value = True)
+                 mask = F.pad(mask, (self.num_cls_tokens, 0), value = True)

          # maybe memory / register tokens

@@ -2415,6 +2417,7 @@ class TransformerWrapper(Module):

          if exists(self.cls_token):
              x, _ = unpack(x, cls_packed_shape, 'b * d')
+             x = x.squeeze(1) # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior

          # handle expansion to mixture if needed (for mixture of softmax)

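The forward-pass hunks above generalize the cls-token packing and mask padding from a single token to num_cls_tokens. A standalone sketch of just that mechanism, mirroring the changed repeat / pack / F.pad lines with illustrative sizes (the variable names b, n, d here are not from the package):

import torch
import torch.nn.functional as F
from einops import repeat, pack

b, n, d, num_cls_tokens = 2, 5, 8, 3   # illustrative sizes

cls_token = torch.zeros(num_cls_tokens, d)   # the parameter now holds one row per cls token
x = torch.randn(b, n, d)
mask = torch.ones(b, n, dtype = torch.bool)

# '... -> b ...' repeats the parameter across the batch regardless of its leading shape
cls_tokens = repeat(cls_token, '... -> b ...', b = b)    # (b, num_cls_tokens, d)
x, cls_packed_shape = pack([cls_tokens, x], 'b * d')     # (b, num_cls_tokens + n, d)

# the attention mask is padded on the left by num_cls_tokens instead of a hard-coded 1
mask = F.pad(mask, (num_cls_tokens, 0), value = True)    # (b, num_cls_tokens + n)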
{x_transformers-1.40.9 → x_transformers-1.40.11/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.40.9
+ Version: 1.40.11
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang