x-transformers 2.1.24__tar.gz → 2.1.26__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.24 → x_transformers-2.1.26}/PKG-INFO +1 -1
  2. {x_transformers-2.1.24 → x_transformers-2.1.26}/pyproject.toml +1 -1
  3. {x_transformers-2.1.24 → x_transformers-2.1.26}/tests/test_x_transformers.py +5 -2
  4. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/belief_state_wrapper.py +18 -15
  5. {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/FUNDING.yml +0 -0
  6. {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/workflows/python-publish.yml +0 -0
  7. {x_transformers-2.1.24 → x_transformers-2.1.26}/.github/workflows/python-test.yaml +0 -0
  8. {x_transformers-2.1.24 → x_transformers-2.1.26}/.gitignore +0 -0
  9. {x_transformers-2.1.24 → x_transformers-2.1.26}/LICENSE +0 -0
  10. {x_transformers-2.1.24 → x_transformers-2.1.26}/README.md +0 -0
  11. {x_transformers-2.1.24 → x_transformers-2.1.26}/data/README.md +0 -0
  12. {x_transformers-2.1.24 → x_transformers-2.1.26}/data/enwik8.gz +0 -0
  13. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/all-attention.png +0 -0
  14. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/deepnorm.png +0 -0
  17. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/fcm.png +0 -0
  23. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/ffglu.png +0 -0
  24. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/flash-attention.png +0 -0
  25. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/gate_values.png +0 -0
  26. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/gating.png +0 -0
  27. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/macaron-1.png +0 -0
  29. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/macaron-2.png +0 -0
  30. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/normformer.png +0 -0
  32. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/pia.png +0 -0
  33. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/resi_dual.png +0 -0
  35. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/residual_attn.png +0 -0
  36. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/rezero.png +0 -0
  37. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/rotary.png +0 -0
  38. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich.png +0 -0
  40. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/scalenorm.png +0 -0
  42. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/talking-heads.png +0 -0
  43. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/topk-attention.png +0 -0
  44. {x_transformers-2.1.24 → x_transformers-2.1.26}/images/xval.png +0 -0
  45. {x_transformers-2.1.24 → x_transformers-2.1.26}/train_belief_state.py +0 -0
  46. {x_transformers-2.1.24 → x_transformers-2.1.26}/train_copy.py +0 -0
  47. {x_transformers-2.1.24 → x_transformers-2.1.26}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.24 → x_transformers-2.1.26}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.24 → x_transformers-2.1.26}/train_parity.py +0 -0
  50. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/xval.py +0 -0
{x_transformers-2.1.24 → x_transformers-2.1.26}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.24
+Version: 2.1.26
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
{x_transformers-2.1.24 → x_transformers-2.1.26}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.24"
+version = "2.1.26"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
{x_transformers-2.1.24 → x_transformers-2.1.26}/tests/test_x_transformers.py
@@ -696,9 +696,11 @@ def test_lime(
 
 @pytest.mark.parametrize('backward_ar_loss_weight', (1., 0.5))
 @pytest.mark.parametrize('goal_suffix', (False, True))
+@pytest.mark.parametrize('pred_distance', (False, True))
 def test_belief_state_wrapper(
     backward_ar_loss_weight,
-    goal_suffix
+    goal_suffix,
+    pred_distance
 ):
     from x_transformers.belief_state_wrapper import BeliefStateWrapper
 
@@ -727,7 +729,8 @@ def test_belief_state_wrapper(
     model = BeliefStateWrapper(
         forward_decoder = forward_model,
         backward_decoder = backward_model,
-        backward_ar_loss_weight = backward_ar_loss_weight
+        backward_ar_loss_weight = backward_ar_loss_weight,
+        pred_distance = pred_distance
     )
 
     seq = torch.randint(0, 20000, (2, 16))
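For context, the new pred_distance flag on BeliefStateWrapper replaces the old pred_terminal option (see the belief_state_wrapper.py diff below). A minimal usage sketch in the spirit of the updated test follows; the TransformerWrapper/Decoder hyperparameters are assumptions rather than the actual test fixtures, and the exact training call may differ from the test body.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# forward and backward decoders (hyperparameters are illustrative only)
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    backward_ar_loss_weight = 0.5,
    pred_distance = True   # renamed from pred_terminal in this release
)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq)   # training loss now includes the distance-prediction term when pred_distance = True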
{x_transformers-2.1.24 → x_transformers-2.1.26}/x_transformers/belief_state_wrapper.py
@@ -48,8 +48,9 @@ class BeliefStateWrapper(Module):
         train_frac_forward_backward_pairs: float = 1.,
         text_head: Module | None = None,
         backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
-        pred_terminal = False,
-        pred_terminal_loss_weight: float = 1.
+        pred_distance = False,
+        pred_distance_loss_weight: float = 1.,
+        max_pred_distance = None
     ):
         super().__init__()
         backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
@@ -59,6 +60,7 @@ class BeliefStateWrapper(Module):
 
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
+        max_seq_len = forward_decoder.max_seq_len
 
         self.num_tokens = num_tokens
 
@@ -80,14 +82,15 @@
 
         # predicting terminal state (when suffix and prefix predict the same token)
 
-        self.to_terminal_logit = nn.Sequential(
+        self.max_pred_distance = default(max_pred_distance, max_seq_len)
+
+        self.to_distance_logits = nn.Sequential(
             nn.Linear(dim * 2, dim),
             nn.LeakyReLU(),
-            nn.Linear(dim, 1),
-            Rearrange('... 1 -> ...')
-        ) if pred_terminal else None
+            nn.Linear(dim, self.max_pred_distance),
+        ) if pred_distance else None
 
-        self.pred_terminal_loss_weight = pred_terminal_loss_weight
+        self.pred_distance_loss_weight = pred_distance_loss_weight
 
         # the two decoders, one which is causal forward, the other causal backwards
 
@@ -314,20 +317,20 @@
 
         # maybe predict terminal
 
-        if exists(self.to_terminal_logit):
-            terminal_logits = self.to_terminal_logit(fb_embeds)
+        if exists(self.to_distance_logits):
+            distance_logits = self.to_distance_logits(fb_embeds)
 
-            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
-            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
+            distance_labels = (bi - fi).clamp(max = self.max_pred_distance - 1)
+            distance_labels = repeat(distance_labels, 'n -> b n', b = batch)
 
-            is_end_loss = F.binary_cross_entropy_with_logits(
-                terminal_logits,
-                terminal_labels
+            pred_dist_loss = F.cross_entropy(
+                rearrange(distance_logits, 'b n l -> b l n'),
+                distance_labels
             )
 
             loss = (
                 loss +
-                is_end_loss * self.pred_terminal_loss_weight
+                pred_dist_loss * self.pred_distance_loss_weight
             )
 
         # maybe early return loss
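Net effect of this file's changes: the binary terminal-state head (BCE on whether the suffix position sits exactly two tokens ahead of the prefix) is replaced by a multi-class head that predicts the prefix-to-suffix distance, with labels clamped into max_pred_distance buckets (defaulting to the decoder's max_seq_len). A self-contained sketch of the new loss computation, using made-up shapes and random tensors in place of the fb_embeds-derived logits and the fi/bi indices:

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

# illustrative sizes only
batch, num_pairs, max_pred_distance = 2, 8, 16

# hypothetical prefix (fi) and suffix (bi) position indices for each pair
fi = torch.arange(num_pairs)
bi = fi + torch.randint(1, 32, (num_pairs,))   # suffix lies some distance to the right

# stand-in for to_distance_logits(fb_embeds): shape (batch, pairs, max_pred_distance)
distance_logits = torch.randn(batch, num_pairs, max_pred_distance)

# labels are the raw distances, clamped into the last bucket
distance_labels = (bi - fi).clamp(max = max_pred_distance - 1)
distance_labels = repeat(distance_labels, 'n -> b n', b = batch)

# F.cross_entropy expects the class dimension second, hence the rearrange
pred_dist_loss = F.cross_entropy(
    rearrange(distance_logits, 'b n l -> b l n'),
    distance_labels
)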