x-transformers 1.27.6__py3-none-any.whl → 1.27.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/__init__.py CHANGED
@@ -23,3 +23,7 @@ from x_transformers.xval import (
  )

  from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
+
+ from x_transformers.dpo import (
+     DPO
+ )
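With this export in place, the new wrapper is importable from the package root rather than only from its submodule, e.g.:

from x_transformers import DPO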
x_transformers/continuous.py CHANGED
@@ -84,6 +84,7 @@ class ContinuousTransformerWrapper(nn.Module):
          mask = None,
          return_attn = False,
          mems = None,
+         mem_masks = None,
          pos = None,
          prepend_embeds = None,
          prepend_mask = None,
@@ -125,7 +126,7 @@ class ContinuousTransformerWrapper(nn.Module):

          # attention layers

-         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
+         x, intermediates = self.attn_layers(x, mask = mask, mems = mems, mem_masks = mem_masks, return_hiddens = True, **kwargs)

          # splice out memory tokens

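These two continuous.py hunks thread a new mem_masks argument through ContinuousTransformerWrapper.forward and into the underlying attention layers, so memory tensors supplied via mems can carry a padding mask. Below is a minimal usage sketch; the model configuration, tensor shapes, and the assumption that mem_masks mirrors the per-layer list structure of mems are illustrative guesses, not taken from the diff.

import torch
from x_transformers import ContinuousTransformerWrapper, Decoder

# hypothetical configuration, purely for illustration
model = ContinuousTransformerWrapper(
    dim_in = 32,
    dim_out = 100,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randn(1, 256, 32)
mask = torch.ones(1, 256).bool()

# one memory tensor per layer, plus a boolean mask marking which memory slots are valid
mems = [torch.randn(1, 16, 512) for _ in range(6)]
mem_masks = [torch.ones(1, 16).bool() for _ in range(6)]

out = model(x, mask = mask, mems = mems, mem_masks = mem_masks)  # (1, 256, 100)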
x_transformers/dpo.py ADDED
@@ -0,0 +1,81 @@
+ from copy import deepcopy
+
+ import torch
+ from torch.nn import Module
+ import torch.nn.functional as F
+ from x_transformers.x_transformers import TransformerWrapper
+
+ from einops import rearrange
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def freeze_all_layers_(module):
+     for param in module.parameters():
+         param.requires_grad = False
+
+ def log(t, eps = 1e-20):
+     return torch.log(t.clamp(min = eps))
+
+ def log_prob(prob, indices, eps = 1e-20):
+     indices = rearrange(indices, '... -> ... 1')
+     log_probs = log(prob.gather(-1, indices), eps = eps)
+     return rearrange(log_probs, '... 1 -> ...')
+
+ def log_prob_from_model_and_seq(model, seq):
+     logits = model(seq)
+     prob = logits.softmax(dim = -1)
+     return log_prob(prob, seq)
+
+ # main class
+
+ class DPO(Module):
+     def __init__(
+         self,
+         model: TransformerWrapper,
+         *,
+         beta = 0.1
+     ):
+         super().__init__()
+         self.policy_model = model
+
+         self.ref_model = deepcopy(model)
+         freeze_all_layers_(self.ref_model)
+
+         self.beta = beta
+
+     def parameters(self):
+         return self.policy_model.parameters()
+
+     def forward(
+         self,
+         preferred_seq,
+         unpreferred_seq,
+         prompt_mask = None
+     ):
+         assert preferred_seq.ndim == 2
+         assert preferred_seq.shape == unpreferred_seq.shape
+
+         """
+         Following Appendix B in https://arxiv.org/abs/2305.18290
+         """
+
+         with torch.no_grad():
+             self.ref_model.eval()
+             ref_preferred_logprob = log_prob_from_model_and_seq(self.ref_model, preferred_seq)
+             ref_unpreferred_logprob = log_prob_from_model_and_seq(self.ref_model, unpreferred_seq)
+
+         policy_preferred_logprob = log_prob_from_model_and_seq(self.policy_model, preferred_seq)
+         policy_unpreferred_logprob = log_prob_from_model_and_seq(self.policy_model, unpreferred_seq)
+
+         policy_logratios = policy_preferred_logprob - policy_unpreferred_logprob
+         ref_logratios = ref_preferred_logprob - ref_unpreferred_logprob
+
+         losses = -F.logsigmoid(self.beta * (policy_logratios - ref_logratios))
+
+         if exists(prompt_mask):
+             losses = losses[~prompt_mask]
+
+         return losses.mean()
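The new dpo.py wraps a policy TransformerWrapper together with a frozen reference copy and returns the Direct Preference Optimization loss (following Appendix B of https://arxiv.org/abs/2305.18290, as noted in the docstring). A usage sketch, based only on the constructor and forward signatures above; the TransformerWrapper hyperparameters and the toy preference batch are assumptions for illustration.

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

# hypothetical policy model; DPO keeps a frozen deepcopy of it as the reference model
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

dpo = DPO(model, beta = 0.1)

# toy preference pair of token id sequences, shape (batch, seq_len)
preferred   = torch.randint(0, 20000, (2, 1024))
unpreferred = torch.randint(0, 20000, (2, 1024))

# True at prompt positions, which forward() excludes from the loss
prompt_mask = torch.zeros(2, 1024).bool()

loss = dpo(preferred, unpreferred, prompt_mask = prompt_mask)
loss.backward()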
x_transformers/x_transformers.py CHANGED
@@ -788,8 +788,8 @@ class Attention(nn.Module):
          # add memory key / values
          self.num_mem_kv = num_mem_kv
          if num_mem_kv > 0:
-             self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
-             self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
+             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

          # attention on attention
          self.attn_on_attn = on_attn
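This x_transformers.py fix sizes the learned memory key / values by kv_heads rather than heads, which matters once grouped-query attention uses fewer key/value heads than query heads. A hedged sketch of a configuration that would exercise the changed code path, assuming the usual attn_-prefixed keyword routing (attn_kv_heads, attn_num_mem_kv) into Attention:

import torch
from x_transformers import TransformerWrapper, Decoder

# grouped-query attention: 8 query heads share 2 key/value heads,
# with 16 learned memory key/values now shaped (kv_heads, num_mem_kv, dim_head)
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_kv_heads = 2,
        attn_num_mem_kv = 16
    )
)

logits = model(torch.randint(0, 20000, (1, 1024)))  # (1, 1024, 20000)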
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.27.6
+ Version: 1.27.8
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
@@ -0,0 +1,14 @@
+ x_transformers/__init__.py,sha256=0-2m0LtLpZiZYGwO-6OMYXofx5hbFb_FJOHMxIBqQr4,673
+ x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
+ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
+ x_transformers/continuous.py,sha256=92Wczoaz6dJalix-e3mdIzW0xyRIx3GlBSgsSQOsJeI,6123
+ x_transformers/dpo.py,sha256=ek9dgiSs05xeCn8ORceOgKy6LJOnNDw-OJDqxAVLecM,2243
+ x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
+ x_transformers/x_transformers.py,sha256=c8axLT-n2zz3mvQ1tBbE4KUs-8qL7yFsgtIujyh1JDg,63408
+ x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
+ x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
+ x_transformers-1.27.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.27.8.dist-info/METADATA,sha256=LYDYUsXQOHYBZRr_5pepdN9HSzaW-2nFX5pEzEOFkcA,661
+ x_transformers-1.27.8.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ x_transformers-1.27.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.27.8.dist-info/RECORD,,
@@ -1,13 +0,0 @@
- x_transformers/__init__.py,sha256=pXc_U4M3ONUQcpNgZySDIlCF1rp7u4FFmcOYjc4WuXw,629
- x_transformers/attend.py,sha256=MFl_FbgPsm9mziZPTi_s8QbxASETwbGeciMH8sUIwT8,10188
- x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRTJCNrYRdJ-Ww,9618
- x_transformers/continuous.py,sha256=SAZGR-3BgXU7OEQtjg1_9FnrUBpIyVfXfpMrH-oL5rU,6074
- x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
- x_transformers/x_transformers.py,sha256=8nBe_MQLfQDHf59pM_c2IiTVHA9frFbFyg4n8S00ZVI,63402
- x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
- x_transformers/xval.py,sha256=ulEPep6i5Hl7H-H9vGfdsmHdprUmK8ajB306jViyV2c,8147
- x_transformers-1.27.6.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.27.6.dist-info/METADATA,sha256=32gbNOf9pJgUoNTdaplhW1mcB4ECiJo5OTCpIVNFWCA,661
- x_transformers-1.27.6.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- x_transformers-1.27.6.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.27.6.dist-info/RECORD,,