x-transformers 1.27.6__tar.gz → 1.27.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {x-transformers-1.27.6/x_transformers.egg-info → x-transformers-1.27.8}/PKG-INFO +1 -1
  2. {x-transformers-1.27.6 → x-transformers-1.27.8}/README.md +11 -0
  3. {x-transformers-1.27.6 → x-transformers-1.27.8}/setup.py +1 -1
  4. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/__init__.py +4 -0
  5. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/continuous.py +2 -1
  6. x-transformers-1.27.8/x_transformers/dpo.py +81 -0
  7. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/x_transformers.py +2 -2
  8. {x-transformers-1.27.6 → x-transformers-1.27.8/x_transformers.egg-info}/PKG-INFO +1 -1
  9. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers.egg-info/SOURCES.txt +1 -0
  10. {x-transformers-1.27.6 → x-transformers-1.27.8}/LICENSE +0 -0
  11. {x-transformers-1.27.6 → x-transformers-1.27.8}/setup.cfg +0 -0
  12. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/attend.py +0 -0
  13. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/autoregressive_wrapper.py +0 -0
  14. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/nonautoregressive_wrapper.py +0 -0
  15. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  16. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/xval.py +0 -0
  17. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers.egg-info/dependency_links.txt +0 -0
  18. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers.egg-info/requires.txt +0 -0
  19. {x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers.egg-info/top_level.txt +0 -0
{x-transformers-1.27.6/x_transformers.egg-info → x-transformers-1.27.8}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.27.6
+ Version: 1.27.8
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x-transformers-1.27.6 → x-transformers-1.27.8}/README.md
@@ -2076,4 +2076,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @article{Rafailov2023DirectPO,
+     title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
+     author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
+     journal = {ArXiv},
+     year    = {2023},
+     volume  = {abs/2305.18290},
+     url     = {https://api.semanticscholar.org/CorpusID:258959321}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x-transformers-1.27.6 → x-transformers-1.27.8}/setup.py
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
  setup(
    name = 'x-transformers',
    packages = find_packages(exclude=['examples']),
-   version = '1.27.6',
+   version = '1.27.8',
    license='MIT',
    description = 'X-Transformers - Pytorch',
    author = 'Phil Wang',
{x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/__init__.py
@@ -23,3 +23,7 @@ from x_transformers.xval import (
  )

  from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
+
+ from x_transformers.dpo import (
+     DPO
+ )
{x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/continuous.py
@@ -84,6 +84,7 @@ class ContinuousTransformerWrapper(nn.Module):
          mask = None,
          return_attn = False,
          mems = None,
+         mem_masks = None,
          pos = None,
          prepend_embeds = None,
          prepend_mask = None,
@@ -125,7 +126,7 @@ class ContinuousTransformerWrapper(nn.Module):

          # attention layers

-         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)

          # splice out memory tokens

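The continuous wrapper's forward now accepts a `mem_masks` argument and threads it through to the attention layers alongside `mems`. Below is a minimal sketch of how the new argument might be passed; the per-layer list format and the tensor sizes are assumptions for illustration, not taken from this diff.

```python
import torch
from x_transformers import ContinuousTransformerWrapper, Encoder

# illustrative sizes only
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 32,
    max_seq_len = 256,
    attn_layers = Encoder(dim = 128, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)
mask = torch.ones(1, 256).bool()

# assumed format: one memory tensor and one matching boolean mask per attention layer
mems      = [torch.randn(1, 16, 128) for _ in range(6)]
mem_masks = [torch.ones(1, 16).bool() for _ in range(6)]

out = model(x, mask = mask, mems = mems, mem_masks = mem_masks)  # (1, 256, 32)
```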
x-transformers-1.27.8/x_transformers/dpo.py
@@ -0,0 +1,81 @@
+ from copy import deepcopy
+
+ import torch
+ from torch.nn import Module
+ import torch.nn.functional as F
+ from x_transformers.x_transformers import TransformerWrapper
+
+ from einops import rearrange
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def freeze_all_layers_(module):
+     for param in module.parameters():
+         param.requires_grad = False
+
+ def log(t, eps = 1e-20):
+     return torch.log(t.clamp(min = eps))
+
+ def log_prob(prob, indices, eps = 1e-20):
+     indices = rearrange(indices, '... -> ... 1')
+     log_probs = log(prob.gather(-1, indices), eps = eps)
+     return rearrange(log_probs, '... 1 -> ...')
+
+ def log_prob_from_model_and_seq(model, seq):
+     logits = model(seq)
+     prob = logits.softmax(dim = -1)
+     return log_prob(prob, seq)
+
+ # main class
+
+ class DPO(Module):
+     def __init__(
+         self,
+         model: TransformerWrapper,
+         *,
+         beta = 0.1
+     ):
+         super().__init__()
+         self.policy_model = model
+
+         self.ref_model = deepcopy(model)
+         freeze_all_layers_(self.ref_model)
+
+         self.beta = beta
+
+     def parameters(self):
+         return self.policy_model.parameters()
+
+     def forward(
+         self,
+         preferred_seq,
+         unpreferred_seq,
+         prompt_mask = None
+     ):
+         assert preferred_seq.ndim == 2
+         assert preferred_seq.shape == unpreferred_seq.shape
+
+         """
+         Following Appendix B in https://arxiv.org/abs/2305.18290
+         """
+
+         with torch.no_grad():
+             self.ref_model.eval()
+             ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
+             ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)
+
+         policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
+         policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
+
+         policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
+         ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob
+
+         losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
+
+         if exists(prompt_mask):
+             losses = losses[~prompt_mask]
+
+         return losses.mean()
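The new `dpo.py` wraps a `TransformerWrapper` as the policy, keeps a frozen deep copy as the reference model, and computes the per-token DPO loss `-log σ(β · (Δ_policy − Δ_ref))` over preferred / unpreferred sequence pairs, optionally masking out prompt positions. A minimal usage sketch follows; the vocabulary size, model dimensions, and prompt length are illustrative, not taken from this diff.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, DPO

# a small decoder-only model standing in for a pretrained policy (sizes are illustrative)
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

dpo = DPO(model, beta = 0.1)

# paired preference data: preferred vs. unpreferred continuation of the same prompt
preferred_seq   = torch.randint(0, 256, (2, 128))
unpreferred_seq = torch.randint(0, 256, (2, 128))

# True marks prompt positions, which are excluded from the loss
prompt_mask = (torch.arange(128) < 16).unsqueeze(0).expand(2, -1)

loss = dpo(preferred_seq, unpreferred_seq, prompt_mask = prompt_mask)
loss.backward()

# dpo.parameters() exposes only the policy model's parameters, so an optimizer
# built from it leaves the frozen reference copy untouched
```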
{x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers/x_transformers.py
@@ -788,8 +788,8 @@ class Attention(nn.Module):
          # add memory key / values
          self.num_mem_kv = num_mem_kv
          if num_mem_kv > 0:
-             self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
-             self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
+             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

          # attention on attention
          self.attn_on_attn = on_attn
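This change sizes the learned memory key / values by `kv_heads` instead of `heads`, so they line up with the key / value heads when grouped-query attention (fewer kv heads than query heads) is combined with memory key / values. A hedged sketch of the configuration this targets, assuming the `attn_kv_heads` and `attn_num_mem_kv` hyperparameters exposed on the attention layers; sizes are illustrative.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = 8,
        attn_kv_heads = 2,     # grouped-query attention: 2 kv heads shared by 8 query heads
        attn_num_mem_kv = 4    # learned memory key / values, now shaped (kv_heads, num_mem_kv, dim_head)
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)  # (1, 128, 256)
```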
{x-transformers-1.27.6 → x-transformers-1.27.8/x_transformers.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.27.6
+ Version: 1.27.8
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
{x-transformers-1.27.6 → x-transformers-1.27.8}/x_transformers.egg-info/SOURCES.txt
@@ -5,6 +5,7 @@ x_transformers/__init__.py
  x_transformers/attend.py
  x_transformers/autoregressive_wrapper.py
  x_transformers/continuous.py
+ x_transformers/dpo.py
  x_transformers/nonautoregressive_wrapper.py
  x_transformers/x_transformers.py
  x_transformers/xl_autoregressive_wrapper.py