x-transformers 1.27.14__py3-none-any.whl → 1.27.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/dpo.py CHANGED
@@ -16,18 +16,32 @@ def freeze_all_layers_(module):
     for param in module.parameters():
         param.requires_grad = False
 
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def log_prob(prob, indices, eps = 1e-20):
-    indices = rearrange(indices, '... -> ... 1')
-    log_probs = log(prob.gather(-1, indices), eps = eps)
-    return rearrange(log_probs, '... 1 -> ...')
-
 def log_prob_from_model_and_seq(model, seq):
     logits = model(seq)
-    prob = logits.softmax(dim = -1)
-    return log_prob(prob, seq)
+    log_prob = logits.log_softmax(dim = -1)
+    indices = rearrange(seq, '... -> ... 1')
+    log_probs = log_prob.gather(-1, indices)
+    return rearrange(log_probs, '... 1 -> ...')
+
+def masked_mean(log_probs, mask = None):
+    if not exists(mask):
+        return log_probs.mean(dim = -1)
+
+    log_probs = log_probs.masked_fill(~mask, 0.)
+    num = log_probs.sum(dim = -1)
+    den = mask.sum(dim = -1)
+    return num / den.clamp(min = 1e-5)
+
+def maybe_and_mask(*masks):
+    masks = [*filter(exists, masks)]
+    if len(masks) == 0:
+        return None
+
+    mask, *rest_masks = masks
+    for rest_mask in rest_masks:
+        mask = mask & rest_mask
+
+    return mask
 
 # main class
 
@@ -36,7 +50,8 @@ class DPO(Module):
         self,
         model: TransformerWrapper,
         *,
-        beta = 0.1
+        beta = 0.1,
+        pad_id = None
     ):
         super().__init__()
         self.policy_model = model
@@ -45,6 +60,7 @@ class DPO(Module):
         freeze_all_layers_(self.ref_model)
 
         self.beta = beta
+        self.pad_id = pad_id
 
     def parameters(self):
         return self.policy_model.parameters()
@@ -53,11 +69,21 @@ class DPO(Module):
         self,
         preferred_seq,
         unpreferred_seq,
-        prompt_mask = None
+        *,
+        prompt_mask,
+        preferred_seq_mask = None,
+        unpreferred_seq_mask = None,
     ):
         assert preferred_seq.ndim == 2
         assert preferred_seq.shape == unpreferred_seq.shape
 
+        if exists(self.pad_id):
+            if not exists(preferred_seq_mask):
+                preferred_seq_mask = preferred_seq != self.pad_id
+
+            if not exists(unpreferred_seq_mask):
+                unpreferred_seq_mask = unpreferred_seq != self.pad_id
+
         """
         Following Appendix B in https://arxiv.org/abs/2305.18290
         """
@@ -70,12 +96,19 @@ class DPO(Module):
         policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
         policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
 
+        # masked mean of log probs
+
+        preferred_seq_mask = maybe_and_mask(~prompt_mask, preferred_seq_mask)
+        unpreferred_seq_mask = maybe_and_mask(~prompt_mask, unpreferred_seq_mask)
+
+        ref_preferred_logprob, policy_preferred_logprob = map(lambda t: masked_mean(t, preferred_seq_mask), (ref_preferred_logprob, policy_preferred_logprob))
+        ref_unpreferred_logprob, policy_unpreferred_logprob = map(lambda t: masked_mean(t, unpreferred_seq_mask), (ref_unpreferred_logprob, policy_unpreferred_logprob))
+
+        # main dpo formula
+
         policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
         ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob
 
         losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
 
-        if exists(prompt_mask):
-            losses = losses[~prompt_mask]
-
         return losses.mean()
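
For orientation, a minimal sketch of how the updated DPO wrapper might be called, with prompt_mask now keyword-only and the new pad_id option used to derive the preferred/unpreferred sequence masks from padding; the model configuration, tensor shapes, and pad_id = 0 below are illustrative assumptions, not values taken from this release.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

# illustrative model config (hypothetical sizes)
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

# pad_id defaults to None; passing it lets DPO build the sequence masks from padding tokens
dpo = DPO(model, beta = 0.1, pad_id = 0)

preferred_seq   = torch.randint(1, 256, (2, 128))
unpreferred_seq = torch.randint(1, 256, (2, 128))

# True marks prompt positions, which are excluded from the masked mean of log probs
prompt_mask = (torch.arange(128) < 16).unsqueeze(0).expand(2, -1)

loss = dpo(preferred_seq, unpreferred_seq, prompt_mask = prompt_mask)
loss.backward()
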
x_transformers-1.27.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.27.14
+Version: 1.27.15
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.27.15.dist-info/RECORD CHANGED
@@ -2,13 +2,13 @@ x_transformers/__init__.py,sha256=0-2m0LtLpZiZYGwO-6OMYXofx5hbFb_FJOHMxIBqQr4,67
 x_transformers/attend.py,sha256=Y3PzYqD3G_x1bYPd6mlp27dp3obaum1O-TOOQaARctc,10188
 x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
 x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
-x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
+x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
 x_transformers/x_transformers.py,sha256=3caIQMDP2pxVuAA-CdEteUqX9ikNSanrmzKjkvzogjE,63619
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.14.dist-info/METADATA,sha256=fXXkd4baN2z6pg5aWlMy-6Jpwb6PtKH-Bntnr6EdYWg,662
-x_transformers-1.27.14.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.27.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.14.dist-info/RECORD,,
+x_transformers-1.27.15.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.15.dist-info/METADATA,sha256=XkrLQTcz-jpF-uZECWTOm1uFAtDVf1Zfm4NEI43dylg,662
+x_transformers-1.27.15.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.15.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.15.dist-info/RECORD,,