x-transformers 1.27.14__py3-none-any.whl → 1.27.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/dpo.py +48 -15
- {x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/METADATA +1 -1
- {x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/RECORD +6 -6
- {x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/top_level.txt +0 -0
x_transformers/dpo.py
CHANGED
@@ -16,18 +16,32 @@ def freeze_all_layers_(module):
     for param in module.parameters():
         param.requires_grad = False
 
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def log_prob(prob, indices, eps = 1e-20):
-    indices = rearrange(indices, '... -> ... 1')
-    log_probs = log(prob.gather(-1, indices), eps = eps)
-    return rearrange(log_probs, '... 1 -> ...')
-
 def log_prob_from_model_and_seq(model, seq):
     logits = model(seq)
-    prob = logits.softmax(dim = -1)
-    return log_prob(prob, seq)
+    log_prob = logits.log_softmax(dim = -1)
+    indices = rearrange(seq, '... -> ... 1')
+    log_probs = log_prob.gather(-1, indices)
+    return rearrange(log_probs, '... 1 -> ...')
+
+def masked_mean(log_probs, mask = None):
+    if not exists(mask):
+        return log_probs.mean(dim = -1)
+
+    log_probs = log_probs.masked_fill(~mask, 0.)
+    num = log_probs.sum(dim = -1)
+    den = mask.sum(dim = -1)
+    return num / den.clamp(min = 1e-5)
+
+def maybe_and_mask(*masks):
+    masks = [*filter(exists, masks)]
+    if len(masks) == 0:
+        return None
+
+    mask, *rest_masks = masks
+    for rest_mask in rest_masks:
+        mask = mask & rest_mask
+
+    return mask
 
 # main class
 
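Note on this hunk: the old `log` / `log_prob` helpers are folded into `log_prob_from_model_and_seq`, which gathers from `log_softmax` directly instead of clamping probabilities and taking the log, and two small mask utilities are added. A minimal, self-contained sketch of how `masked_mean` behaves (toy tensors invented for illustration; the diff's `clamp(min = 1e-5)` guard is written as an equivalent integer clamp here):

import torch

def exists(v):
    return v is not None

def masked_mean(log_probs, mask = None):
    # mean over the last dim, counting only positions where mask is True
    if not exists(mask):
        return log_probs.mean(dim = -1)

    log_probs = log_probs.masked_fill(~mask, 0.)
    den = mask.sum(dim = -1).clamp(min = 1)  # guard against all-False rows
    return log_probs.sum(dim = -1) / den

log_probs = torch.tensor([[-1., -2., -3., -4.],
                          [-1., -1., -1., -1.]])

mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])

print(masked_mean(log_probs, mask))  # tensor([-1.5000, -1.0000])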
@@ -36,7 +50,8 @@ class DPO(Module):
         self,
         model: TransformerWrapper,
         *,
-        beta = 0.1
+        beta = 0.1,
+        pad_id = None
     ):
         super().__init__()
         self.policy_model = model
@@ -45,6 +60,7 @@ class DPO(Module):
         freeze_all_layers_(self.ref_model)
 
         self.beta = beta
+        self.pad_id = pad_id
 
     def parameters(self):
         return self.policy_model.parameters()
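The new `pad_id` argument stores which token id marks padding, so `forward` can derive sequence masks itself. A hypothetical construction, with model hyperparameters invented for illustration (the import path follows the package layout shown in the RECORD below):

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

# beta scales the log-ratio margin; pad_id enables automatic pad masking
dpo_wrapper = DPO(model, beta = 0.1, pad_id = 0)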
@@ -53,11 +69,21 @@ class DPO(Module):
         self,
         preferred_seq,
         unpreferred_seq,
-        prompt_mask = None
+        *,
+        prompt_mask,
+        preferred_seq_mask = None,
+        unpreferred_seq_mask = None,
     ):
         assert preferred_seq.ndim == 2
         assert preferred_seq.shape == unpreferred_seq.shape
 
+        if exists(self.pad_id):
+            if not exists(preferred_seq_mask):
+                preferred_seq_mask = preferred_seq != self.pad_id
+
+            if not exists(unpreferred_seq_mask):
+                unpreferred_seq_mask = unpreferred_seq != self.pad_id
+
         """
         Following Appendix B in https://arxiv.org/abs/2305.18290
         """
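With this signature, `prompt_mask` becomes keyword-only and required, while `preferred_seq_mask` / `unpreferred_seq_mask` fall back to `seq != pad_id` when `pad_id` was set. A hypothetical call, reusing the `dpo_wrapper` sketch above (batch size, lengths, and token ids invented for illustration):

import torch

# batch of 2 sequences of length 16; the first 4 tokens are the shared prompt
preferred_seq   = torch.randint(1, 256, (2, 16))
unpreferred_seq = torch.randint(1, 256, (2, 16))

prompt_mask = torch.zeros(2, 16, dtype = torch.bool)
prompt_mask[:, :4] = True  # True marks prompt positions, excluded from the loss

loss = dpo_wrapper(preferred_seq, unpreferred_seq, prompt_mask = prompt_mask)
loss.backward()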
@@ -70,12 +96,19 @@
         policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
         policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
 
+        # masked mean of log probs
+
+        preferred_seq_mask = maybe_and_mask(~prompt_mask, preferred_seq_mask)
+        unpreferred_seq_mask = maybe_and_mask(~prompt_mask, unpreferred_seq_mask)
+
+        ref_preferred_logprob, policy_preferred_logprob = map(lambda t: masked_mean(t, preferred_seq_mask), (ref_preferred_logprob, policy_preferred_logprob))
+        ref_unpreferred_logprob, policy_unpreferred_logprob = map(lambda t: masked_mean(t, unpreferred_seq_mask), (ref_unpreferred_logprob, policy_unpreferred_logprob))
+
+        # main dpo formula
+
         policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
         ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob
 
         losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
 
-        if exists(prompt_mask):
-            losses = losses[~prompt_mask]
-
         return losses.mean()
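Net effect of this hunk: per-token log-probabilities are first reduced to per-sequence means over completion tokens (prompt positions removed via `~prompt_mask`, pad positions via the sequence masks), and only then enter the DPO objective loss = -log sigmoid(beta * ((log pi_theta(y_w) - log pi_theta(y_l)) - (log pi_ref(y_w) - log pi_ref(y_l)))). The old code instead computed per-token losses and indexed them with `~prompt_mask` afterwards. A numeric sketch of the formula with invented scalar log-probs:

import torch
import torch.nn.functional as F

beta = 0.1

# per-sequence mean log-probs, as produced by masked_mean (values invented)
policy_preferred   = torch.tensor(-1.2)
policy_unpreferred = torch.tensor(-2.0)
ref_preferred      = torch.tensor(-1.5)
ref_unpreferred    = torch.tensor(-1.8)

policy_logratios = policy_preferred - policy_unpreferred  # 0.8
ref_logratios    = ref_preferred - ref_unpreferred        # 0.3

loss = -F.logsigmoid(beta * (policy_logratios - ref_logratios))
print(loss)  # -log(sigmoid(0.05)) ≈ 0.6685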
{x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/RECORD
CHANGED
@@ -2,13 +2,13 @@ x_transformers/__init__.py,sha256=0-2m0LtLpZiZYGwO-6OMYXofx5hbFb_FJOHMxIBqQr4,67
 x_transformers/attend.py,sha256=Y3PzYqD3G_x1bYPd6mlp27dp3obaum1O-TOOQaARctc,10188
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
 x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
-x_transformers/dpo.py,sha256=
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
 x_transformers/x_transformers.py,sha256=3caIQMDP2pxVuAA-CdEteUqX9ikNSanrmzKjkvzogjE,63619
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
-x_transformers-1.27.
+x_transformers-1.27.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.15.dist-info/METADATA,sha256=XkrLQTcz-jpF-uZECWTOm1uFAtDVf1Zfm4NEI43dylg,662
+x_transformers-1.27.15.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.15.dist-info/RECORD,,
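For reference, each RECORD entry has the form `path,sha256=<digest>,size`, where the digest is the file's SHA-256 in unpadded urlsafe base64, per the wheel spec (PEP 376/427). A quick sketch for checking the updated dpo.py entry against an unpacked wheel (the local path is assumed):

import base64, hashlib

def record_digest(path):
    # unpadded urlsafe base64 of the file's sha256, as used in RECORD
    digest = hashlib.sha256(open(path, 'rb').read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# expect: LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls
print(record_digest('x_transformers/dpo.py'))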
{x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/LICENSE
File without changes
{x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/WHEEL
File without changes
{x_transformers-1.27.14.dist-info → x_transformers-1.27.15.dist-info}/top_level.txt
File without changes