x-transformers 1.27.6__py3-none-any.whl → 1.27.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/__init__.py +4 -0
- x_transformers/continuous.py +2 -1
- x_transformers/dpo.py +81 -0
- x_transformers/x_transformers.py +2 -2
- {x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/METADATA +1 -1
- x_transformers-1.27.8.dist-info/RECORD +14 -0
- x_transformers-1.27.6.dist-info/RECORD +0 -13
- {x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/top_level.txt +0 -0
x_transformers/__init__.py
CHANGED
x_transformers/continuous.py
CHANGED
@@ -84,6 +84,7 @@ class ContinuousTransformerWrapper(nn.Module):
         mask = None,
         return_attn = False,
         mems = None,
+        mem_masks = None,
         pos = None,
         prepend_embeds = None,
         prepend_mask = None,
@@ -125,7 +126,7 @@ class ContinuousTransformerWrapper(nn.Module):
 
         # attention layers
 
-        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+        x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)
 
         # splice out memory tokens
 
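The new mem_masks argument is forwarded to attn_layers exactly as mems already was, so callers can mask out individual positions of the per-layer memories. Below is a minimal sketch of how a caller might pass the two together; the model configuration, the one-tensor-per-attention-layer list format, and the memory shapes are illustrative assumptions rather than anything specified in this diff.

    import torch
    from x_transformers import ContinuousTransformerWrapper, Encoder

    # illustrative configuration -- not taken from the diff
    model = ContinuousTransformerWrapper(
        dim_in = 32,
        dim_out = 32,
        max_seq_len = 128,
        attn_layers = Encoder(dim = 64, depth = 2, heads = 4)
    )

    x = torch.randn(1, 128, 32)
    mask = torch.ones(1, 128).bool()

    # assumed format: one memory tensor and one boolean memory mask per attention layer
    mems      = [torch.randn(1, 16, 64) for _ in range(2)]
    mem_masks = [torch.ones(1, 16).bool() for _ in range(2)]

    out = model(x, mask = mask, mems = mems, mem_masks = mem_masks)  # (1, 128, 32)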
x_transformers/dpo.py
ADDED
@@ -0,0 +1,81 @@
+from copy import deepcopy
+
+import torch
+from torch.nn import Module
+import torch.nn.functional as F
+from x_transformers.x_transformers import TransformerWrapper
+
+from einops import rearrange
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def freeze_all_layers_(module):
+    for param in module.parameters():
+        param.requires_grad = False
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def log_prob(prob, indices, eps = 1e-20):
+    indices = rearrange(indices, '... -> ... 1')
+    log_probs = log(prob.gather(-1, indices), eps = eps)
+    return rearrange(log_probs, '... 1 -> ...')
+
+def log_prob_from_model_and_seq(model, seq):
+    logits = model(seq)
+    prob = logits.softmax(dim = -1)
+    return log_prob(prob, seq)
+
+# main class
+
+class DPO(Module):
+    def __init__(
+        self,
+        model: TransformerWrapper,
+        *,
+        beta = 0.1
+    ):
+        super().__init__()
+        self.policy_model = model
+
+        self.ref_model = deepcopy(model)
+        freeze_all_layers_(self.ref_model)
+
+        self.beta = beta
+
+    def parameters(self):
+        return self.policy_model.parameters()
+
+    def forward(
+        self,
+        preferred_seq,
+        unpreferred_seq,
+        prompt_mask = None
+    ):
+        assert preferred_seq.ndim == 2
+        assert preferred_seq.shape == unpreferred_seq.shape
+
+        """
+        Following Appendix B in https://arxiv.org/abs/2305.18290
+        """
+
+        with torch.no_grad():
+            self.ref_model.eval()
+            ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
+            ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)
+
+        policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
+        policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
+
+        policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
+        ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob
+
+        losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
+
+        if exists(prompt_mask):
+            losses = losses[~prompt_mask]
+
+        return losses.mean()
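The module pairs the trainable policy (the wrapped TransformerWrapper) with a frozen deep copy serving as the reference model, and its forward pass returns the DPO objective -logsigmoid(beta * (policy log-ratio minus reference log-ratio)) averaged over non-prompt positions. A minimal usage sketch follows; the model size, sequence length, prompt length and random preference data are illustrative assumptions, not part of the package.

    import torch
    from x_transformers import TransformerWrapper, Decoder
    from x_transformers.dpo import DPO

    # illustrative policy model -- any TransformerWrapper works
    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
    )

    dpo = DPO(model, beta = 0.1)
    optimizer = torch.optim.Adam(dpo.parameters(), lr = 3e-4)

    # paired preference data: same prompt, preferred vs. unpreferred continuation (dummy tokens here)
    preferred_seq   = torch.randint(0, 256, (2, 128))
    unpreferred_seq = torch.randint(0, 256, (2, 128))

    # True over the shared prompt positions, which are excluded from the loss
    prompt_mask = torch.zeros(2, 128).bool()
    prompt_mask[:, :16] = True

    loss = dpo(preferred_seq, unpreferred_seq, prompt_mask = prompt_mask)
    loss.backward()
    optimizer.step()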
x_transformers/x_transformers.py
CHANGED
@@ -788,8 +788,8 @@ class Attention(nn.Module):
         # add memory key / values
         self.num_mem_kv = num_mem_kv
         if num_mem_kv > 0:
-            self.mem_k = nn.Parameter(torch.randn(
-            self.mem_v = nn.Parameter(torch.randn(
+            self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
+            self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
 
         # attention on attention
         self.attn_on_attn = on_attn
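For reference, the corrected parameters give one set of learned memory key / values per key/value head, shaped (kv_heads, num_mem_kv, dim_head); at attention time they are broadcast over the batch and prepended to the projected keys and values. The standalone sketch below only illustrates that shape arithmetic; the variable names and the repeat/concat pattern are assumptions for illustration, not the package's exact code.

    import torch
    from einops import repeat

    batch, kv_heads, seq_len, dim_head, num_mem_kv = 2, 2, 32, 64, 4

    # learned memory key / values, one set per key/value head
    mem_k = torch.randn(kv_heads, num_mem_kv, dim_head)
    mem_v = torch.randn(kv_heads, num_mem_kv, dim_head)

    # projected keys / values for a batch of sequences
    k = torch.randn(batch, kv_heads, seq_len, dim_head)
    v = torch.randn(batch, kv_heads, seq_len, dim_head)

    # broadcast the memories over the batch and prepend them along the sequence axis
    k = torch.cat((repeat(mem_k, 'h n d -> b h n d', b = batch), k), dim = -2)
    v = torch.cat((repeat(mem_v, 'h n d -> b h n d', b = batch), v), dim = -2)

    # k, v are now (batch, kv_heads, num_mem_kv + seq_len, dim_head)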
x_transformers-1.27.8.dist-info/RECORD
ADDED
@@ -0,0 +1,14 @@
+x_transformers/__init__.py,sha256=0-2m0LtLpZiZYGwO-6OMYXofx5hbFb_FJOHMxIBqQr4,673
+x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
+x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
+x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
+x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
+x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+x_transformers/x_transformers.py,sha256=c8axLT-n2zz3mvQ1tBbE4KUs-8qL7yFsgtIujyh1JDg,63408
+x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
+x_transformers-1.27.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.27.8.dist-info/METADATA,sha256=LYDYUsXQOHYBZRr_5pepdN9HSzaW-2nFX5pEzEOFkcA,661
+x_transformers-1.27.8.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+x_transformers-1.27.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.27.8.dist-info/RECORD,,
x_transformers-1.27.6.dist-info/RECORD
DELETED
@@ -1,13 +0,0 @@
-x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
-x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
-x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
-x_transformers/continuous.py,sha256=SAZGR-3BgXU7OEQtjg1_9FnrUBpIyVfXfpMrH-oL5rU,6074
-x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=8nBe_MQLfQDHf59pM_c2IiTVHA9frFbFyg4n8S00ZVI,63402
-x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
-x_transformers-1.27.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.27.6.dist-info/METADATA,sha256=32gbNOf9pJgUoNTdaplhW1mcB4ECiJo5OTCpIVNFWCA,661
-x_transformers-1.27.6.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-x_transformers-1.27.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.27.6.dist-info/RECORD,,
{x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/LICENSE
File without changes

{x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/WHEEL
File without changes

{x_transformers-1.27.6.dist-info → x_transformers-1.27.8.dist-info}/top_level.txt
File without changes