x-transformers 2.3.17__tar.gz → 2.3.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {x_transformers-2.3.17 → x_transformers-2.3.19}/PKG-INFO +1 -1
  2. {x_transformers-2.3.17 → x_transformers-2.3.19}/pyproject.toml +1 -1
  3. {x_transformers-2.3.17 → x_transformers-2.3.19}/tests/test_x_transformers.py +33 -3
  4. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/attend.py +1 -0
  5. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/continuous.py +17 -6
  6. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/x_transformers.py +16 -3
  7. {x_transformers-2.3.17 → x_transformers-2.3.19}/.github/FUNDING.yml +0 -0
  8. {x_transformers-2.3.17 → x_transformers-2.3.19}/.github/workflows/python-publish.yml +0 -0
  9. {x_transformers-2.3.17 → x_transformers-2.3.19}/.github/workflows/python-test.yaml +0 -0
  10. {x_transformers-2.3.17 → x_transformers-2.3.19}/.gitignore +0 -0
  11. {x_transformers-2.3.17 → x_transformers-2.3.19}/LICENSE +0 -0
  12. {x_transformers-2.3.17 → x_transformers-2.3.19}/README.md +0 -0
  13. {x_transformers-2.3.17 → x_transformers-2.3.19}/data/README.md +0 -0
  14. {x_transformers-2.3.17 → x_transformers-2.3.19}/data/enwik8.gz +0 -0
  15. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/all-attention.png +0 -0
  16. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/attention-on-attention.png +0 -0
  17. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/cosine-sim-attention.png +0 -0
  18. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/deepnorm.png +0 -0
  19. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/dynamic-pos-bias-linear.png +0 -0
  20. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/dynamic-pos-bias-log.png +0 -0
  21. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  22. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/dynamic-pos-bias.png +0 -0
  23. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/enhanced-recurrence.png +0 -0
  24. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/fcm.png +0 -0
  25. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/ffglu.png +0 -0
  26. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/flash-attention.png +0 -0
  27. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/gate_values.png +0 -0
  28. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/gating.png +0 -0
  29. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/length-extrapolation-scale.png +0 -0
  30. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/macaron-1.png +0 -0
  31. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/macaron-2.png +0 -0
  32. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/memory-transformer.png +0 -0
  33. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/normformer.png +0 -0
  34. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/pia.png +0 -0
  35. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/qknorm-analysis.png +0 -0
  36. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/resi_dual.png +0 -0
  37. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/residual_attn.png +0 -0
  38. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/rezero.png +0 -0
  39. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/rotary.png +0 -0
  40. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/sandwich-2.png +0 -0
  41. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/sandwich.png +0 -0
  42. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/sandwich_norm.png +0 -0
  43. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/scalenorm.png +0 -0
  44. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/talking-heads.png +0 -0
  45. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/topk-attention.png +0 -0
  46. {x_transformers-2.3.17 → x_transformers-2.3.19}/images/xval.png +0 -0
  47. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_belief_state.py +0 -0
  48. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_copy.py +0 -0
  49. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_entropy_tokenizer.py +0 -0
  50. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_enwik8.py +0 -0
  51. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_length_extrapolate.py +0 -0
  52. {x_transformers-2.3.17 → x_transformers-2.3.19}/train_parity.py +0 -0
  53. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/__init__.py +0 -0
  54. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/dpo.py +0 -0
  57. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/entropy_based_tokenizer.py +0 -0
  58. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/multi_input.py +0 -0
  59. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/neo_mlp.py +0 -0
  60. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/nonautoregressive_wrapper.py +0 -0
  61. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  62. {x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/xval.py +0 -0

{x_transformers-2.3.17 → x_transformers-2.3.19}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.3.17
+Version: 2.3.19
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers

{x_transformers-2.3.17 → x_transformers-2.3.19}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.3.17"
+version = "2.3.19"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

{x_transformers-2.3.17 → x_transformers-2.3.19}/tests/test_x_transformers.py
@@ -651,6 +651,39 @@ def test_hybrid(hybrid_axial_dim):
     mask = torch.randint(0, 2, (2, 1024)).bool()
     embed = enc(x, mask = mask)
 
+def test_hybrid_cache():
+    from torch.nn import GRU
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            attn_hybrid_fold_axial_dim = 1,
+            attn_hybrid_module = GRU(128, 64 * 8, batch_first = True)
+        )
+    )
+
+    x = torch.randint(0, 20000, (2, 4))
+
+    # parallel
+
+    out_parallel = model(x)
+
+    # sequential
+
+    x_without_last = x[:, :-1]
+
+    out1, cache = model(x_without_last, return_intermediates = True)
+    out2 = model(x, cache = cache)
+
+    out_seq = torch.cat((out1, out2), dim = 1)
+
+    assert torch.allclose(out_parallel, out_seq, atol = 1e-5)
+
 def test_multi_latent_attention():
     model = TransformerWrapper(
         num_tokens = 20000,
@@ -876,9 +909,6 @@ def test_continuous(
     cache_kv,
     rollout_steps
 ):
-    if probabilistic and rollout_steps > 1:
-        pytest.skip()
-
     from x_transformers import (
         ContinuousTransformerWrapper,
         Decoder,
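
The skip removed above existed because probabilistic models could not yet do multi-step rollouts; the continuous.py changes below add that by sampling each next continuous token from the predicted (mean, variance) pair and feeding it back in. A rough, self-contained sketch of that feedback pattern in plain PyTorch follows; toy_net is a hypothetical stand-in for a probabilistic continuous model, not the package's API.

import torch

def toy_net(seq):
    # hypothetical stand-in: predicts a (mean, variance) pair per position
    return seq * 0.5, torch.ones_like(seq)

seq = torch.randn(1, 3, 16)        # prompt of 3 continuous tokens
preds = []

for _ in range(4):                 # rollout steps
    mean, variance = toy_net(seq)
    last_mean, last_var = mean[:, -1:], variance[:, -1:]

    # sample the next token from the predicted distribution, then feed it back
    std = last_var.clamp(min = 1e-5).sqrt()
    nxt = torch.normal(last_mean, std)

    preds.append(nxt)
    seq = torch.cat((seq, nxt), dim = 1)

rollout = torch.cat(preds, dim = 1)    # (1, 4, 16)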

{x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/attend.py
@@ -25,6 +25,7 @@ class Intermediates:
     values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
+    hybrid_hidden: Tensor | None = None
 
     def to_tuple(self):
         return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)

{x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/continuous.py
@@ -32,6 +32,15 @@ def default(val, d):
         return val
     return d() if not isinstance(d, Module) and callable(d) else d
 
+def sample_from_mean_variance(
+    mean,
+    variance,
+    eps = 1e-5,
+    temperature = 1.
+):
+    std = variance.clamp(min = eps).sqrt()
+    return torch.normal(mean, std * temperature)
+
 def masked_mean(t, mask):
     t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
 
@@ -274,9 +283,7 @@ class ContinuousAutoregressiveWrapper(Module):
 
             if self.probabilistic:
                 mean, var = last_output
-                stddev = var.clamp(min = 1e-5).sqrt()
-
-                last_output = torch.normal(mean, stddev * temperature)
+                last_output = sample_from_mean_variance(mean, var, temperature = temperature)
 
             out = cat((out, last_output), dim = -2)
 
@@ -298,7 +305,6 @@ class ContinuousAutoregressiveWrapper(Module):
         **kwargs
     ):
         assert rollout_steps > 1
-        assert not self.probabilistic, 'probabilistic not supported yet'
 
         steps = rollout_steps
 
@@ -369,8 +375,13 @@ class ContinuousAutoregressiveWrapper(Module):
                 **kwargs
             )
 
-            last_pred = out[:, -1:]
-            inp = last_pred
+            last_pred = out[..., -1:, :]
+
+            if self.probabilistic:
+                mean, var = last_pred
+                inp = sample_from_mean_variance(mean, var)
+            else:
+                inp = last_pred
 
             preds.append(last_pred)
 

{x_transformers-2.3.17 → x_transformers-2.3.19}/x_transformers/x_transformers.py
@@ -1079,10 +1079,11 @@ class FoldAxially(Module):
     def forward(
         self,
         x,
+        *args,
         **kwargs
     ):
         if self.axial_dim == 1:
-            return self.fn(x, **kwargs)
+            return self.fn(x, *args, **kwargs)
 
         seq_len, axial_dim = x.shape[1], self.axial_dim
 
@@ -1091,7 +1092,7 @@ class FoldAxially(Module):
 
         x = rearrange(x, 'b (n axial_dim) ... -> (b axial_dim) n ...', axial_dim = axial_dim)
 
-        out = self.fn(x, **kwargs)
+        out = self.fn(x, *args, **kwargs)
 
         (out, *rest_out), tree_spec = tree_flatten(out)
 
@@ -1857,9 +1858,17 @@ class Attention(Module):
             if not self.causal and exists(self.hybrid_mask_kwarg):
                 hybrid_forward_kwargs = {self.hybrid_mask_kwarg: mask}
 
+            # handle maybe hybrid cache
+
+            hybrid_forward_args = ()
+
+            if exists(cache) and exists(cache.hybrid_hidden):
+                hybrid_hiddens = cache.hybrid_hidden
+                hybrid_forward_args = (hybrid_hiddens,)
+
             # hybrid forward
 
-            hybrid_outputs = self.hybrid_module(x, **hybrid_forward_kwargs)
+            hybrid_outputs = self.hybrid_module(x, *hybrid_forward_args, **hybrid_forward_kwargs)
 
             # handle hybrid out
 
@@ -1870,6 +1879,10 @@ class Attention(Module):
             if hybrid_out.ndim == 3:
                 hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
 
+            if len(rest_hybrid_outs) > 0:
+                hybrid_hidden = first(rest_hybrid_outs)
+                intermediates.hybrid_hidden = hybrid_hidden
+
             out_norm, hybrid_out_norm = self.hybrid_norms
 
             out = out_norm(out)
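
The hybrid cache plumbing above (the hybrid_hidden field on Intermediates, the *args passthrough in FoldAxially, and the cache read/write in Attention) relies on a recurrent hybrid module reproducing its full-sequence output when resumed from its stored hidden state, which is what test_hybrid_cache asserts. A quick standalone check of that property with torch.nn.GRU; the sizes are arbitrary and only for illustration.

import torch
from torch.nn import GRU

gru = GRU(128, 512, batch_first = True)
x = torch.randn(2, 4, 128)

# whole sequence in one parallel pass
out_parallel, _ = gru(x)

# same sequence split in two, carrying the hidden state across the split
out_prefix, hidden = gru(x[:, :-1])
out_last, _ = gru(x[:, -1:], hidden)

out_seq = torch.cat((out_prefix, out_last), dim = 1)
assert torch.allclose(out_parallel, out_seq, atol = 1e-5)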