x-transformers 2.3.0__tar.gz → 2.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {x_transformers-2.3.0 → x_transformers-2.3.2}/PKG-INFO +1 -1
  2. {x_transformers-2.3.0 → x_transformers-2.3.2}/pyproject.toml +1 -1
  3. {x_transformers-2.3.0 → x_transformers-2.3.2}/tests/test_x_transformers.py +4 -2
  4. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/continuous.py +19 -3
  5. {x_transformers-2.3.0 → x_transformers-2.3.2}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.3.0 → x_transformers-2.3.2}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.3.0 → x_transformers-2.3.2}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.3.0 → x_transformers-2.3.2}/.gitignore +0 -0
  9. {x_transformers-2.3.0 → x_transformers-2.3.2}/LICENSE +0 -0
  10. {x_transformers-2.3.0 → x_transformers-2.3.2}/README.md +0 -0
  11. {x_transformers-2.3.0 → x_transformers-2.3.2}/data/README.md +0 -0
  12. {x_transformers-2.3.0 → x_transformers-2.3.2}/data/enwik8.gz +0 -0
  13. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/all-attention.png +0 -0
  14. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/deepnorm.png +0 -0
  17. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/fcm.png +0 -0
  23. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/ffglu.png +0 -0
  24. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/flash-attention.png +0 -0
  25. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/gate_values.png +0 -0
  26. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/gating.png +0 -0
  27. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/macaron-1.png +0 -0
  29. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/macaron-2.png +0 -0
  30. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/normformer.png +0 -0
  32. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/pia.png +0 -0
  33. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/resi_dual.png +0 -0
  35. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/residual_attn.png +0 -0
  36. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/rezero.png +0 -0
  37. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/rotary.png +0 -0
  38. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/sandwich.png +0 -0
  40. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/scalenorm.png +0 -0
  42. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/talking-heads.png +0 -0
  43. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/topk-attention.png +0 -0
  44. {x_transformers-2.3.0 → x_transformers-2.3.2}/images/xval.png +0 -0
  45. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_belief_state.py +0 -0
  46. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_copy.py +0 -0
  47. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_enwik8.py +0 -0
  49. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_length_extrapolate.py +0 -0
  50. {x_transformers-2.3.0 → x_transformers-2.3.2}/train_parity.py +0 -0
  51. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/__init__.py +0 -0
  52. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/attend.py +0 -0
  53. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/autoregressive_wrapper.py +0 -0
  54. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/belief_state_wrapper.py +0 -0
  55. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/dpo.py +0 -0
  56. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/entropy_based_tokenizer.py +0 -0
  57. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/multi_input.py +0 -0
  58. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/neo_mlp.py +0 -0
  59. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
  60. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/x_transformers.py +0 -0
  61. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.0 → x_transformers-2.3.2}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.0
+Version: 2.3.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.0"
+version = "2.3.2"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -849,8 +849,10 @@ def test_custom_ff_activation():
     assert logits.shape == (2, 1024, 20000)
 
 @pytest.mark.parametrize('probabilistic', (False, True))
+@pytest.mark.parametrize('cache_kv', (False, True))
 def test_continuous(
-    probabilistic
+    probabilistic,
+    cache_kv
 ):
     from x_transformers import (
         ContinuousTransformerWrapper,
@@ -887,5 +889,5 @@ def test_continuous(
     # then generate
 
     start_emb = torch.randn(1, 777)
-    generated = model.generate(start_emb, 17) # (17, 777)
+    generated = model.generate(start_emb, 17, cache_kv = cache_kv) # (17, 777)
     assert generated.shape == (17, 777)
x_transformers/continuous.py
@@ -10,6 +10,7 @@ import einx
 from einops import rearrange, reduce, pack, repeat, unpack
 
 from x_transformers.x_transformers import (
+    Attention,
     AttentionLayers,
     ScaledSinusoidalEmbedding,
     AbsolutePositionalEmbedding,
@@ -111,6 +112,10 @@ class ContinuousTransformerWrapper(Module):
 
         self.project_out = nn.Linear(dim, dim_out * (2 if probabilistic else 1), bias = False) if exists(dim_out) else nn.Identity()
 
+        # can cache kv
+
+        self.can_cache_kv = all([module.can_cache_kv for module in self.modules() if isinstance(module, Attention)])
+
     def forward(
         self,
         x,
@@ -180,7 +185,7 @@ class ContinuousTransformerWrapper(Module):
         if not return_embeddings and self.probabilistic:
             mean, log_var = rearrange(out, '... (d mean_log_var) -> mean_log_var ... d', mean_log_var = 2)
             variance = log_var.exp()
-            return stack((mean, variance))
+            out = stack((mean, variance))
 
         if return_intermediates:
             return out, intermediates
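
The hunk above is what makes caching usable with the probabilistic head: previously this branch returned the stacked (mean, variance) immediately and skipped the return out, intermediates path, so a caller asking for intermediates (which is how the new generate loop below retrieves the kv cache) never received them. Assigning to out lets both flow through. A minimal sketch of the forward call this enables; the wrapper construction and dimensions are illustrative assumptions, only probabilistic and return_intermediates behavior comes from these hunks:

    import torch
    from x_transformers import ContinuousTransformerWrapper, Decoder

    # hypothetical wrapper; the dimensions are placeholders, not taken from this diff
    net = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 32,
        max_seq_len = 256,
        probabilistic = True,
        attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
    )

    x = torch.randn(1, 10, 32)

    # with the fix, the probabilistic output and the intermediates (which carry
    # the kv cache used during generation) are returned together
    out, intermediates = net(x, return_intermediates = True)
    mean, variance = out  # out is stack((mean, variance)) along a new leading dim
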
@@ -222,9 +227,13 @@ class ContinuousAutoregressiveWrapper(Module):
         self,
         start_tokens,
         seq_len,
+        temperature = 1.,
+        cache_kv = True,
         **kwargs
     ):
+        should_cache_kv = cache_kv and self.net.can_cache_kv
         device = start_tokens.device
+
         was_training = self.net.training
         num_dims = len(start_tokens.shape)
 
@@ -238,17 +247,24 @@
         self.net.eval()
         out = start_tokens
 
+        cache = None
+
         for _ in range(seq_len):
             x = out[:, -self.max_seq_len:]
 
-            last_output = self.net(x, **kwargs)[..., -1:, :]
+            net_out, new_cache = self.net(x, cache = cache, return_intermediates = True, **kwargs)
+
+            last_output = net_out[..., -1:, :]
 
             if self.probabilistic:
                 mean, var = last_output
-                last_output = torch.normal(mean, var)
+                last_output = torch.normal(mean, var * temperature)
 
             out = cat((out, last_output), dim = -2)
 
+            if should_cache_kv:
+                cache = new_cache
+
         out = out[:, t:]
 
         if num_dims == 2:
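
Taken together, the generate() hunks add an opt-in key/value cache for decoding (used only when the wrapped network reports can_cache_kv, per the should_cache_kv gate above) and a temperature that scales the variance term of the probabilistic head before sampling. A minimal usage sketch; the wrapper construction mirrors the partially shown test file and is an assumption, while the temperature and cache_kv arguments come from this diff:

    import torch
    from x_transformers import (
        ContinuousTransformerWrapper,
        Decoder,
        ContinuousAutoregressiveWrapper,
    )

    # backbone over continuous vectors; dims follow the test shown earlier (assumed)
    model = ContinuousTransformerWrapper(
        dim_in = 777,
        dim_out = 777,
        max_seq_len = 1024,
        probabilistic = True,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
    )

    model = ContinuousAutoregressiveWrapper(model)

    # decode 17 continuous tokens from a single start embedding; cache_kv defaults
    # to True and is silently skipped when the network cannot cache
    start_emb = torch.randn(1, 777)
    generated = model.generate(start_emb, 17, temperature = 1., cache_kv = True)  # (17, 777)
    assert generated.shape == (17, 777)
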