x_transformers-2.1.2-py3-none-any.whl → x_transformers-2.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
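The main change in this release is the new `x_transformers/belief_state.py` module below, which trains a forward and a backward decoder jointly on the belief state objective. Here is a minimal usage sketch, assuming the usual `TransformerWrapper`/`Decoder` constructors and importing from the module path listed in the RECORD; model sizes and the learning rate are illustrative only. Note that `BeliefStateWrapper.forward` already calls `.backward()` on the loss internally, so a training step only needs the optimizer update:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state import BeliefStateWrapper

# two decoders, one reading left-to-right, one right-to-left (sizes illustrative)
forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

backward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

wrapper = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model,
    train_frac_forward_backward_pairs = 0.5  # subsample half the (prefix, suffix) pairs
)

optim = torch.optim.Adam(wrapper.parameters(), lr = 3e-4)

seq = torch.randint(0, 256, (2, 128))

loss = wrapper(seq)  # gradients are already populated inside forward()
optim.step()
optim.zero_grad()
```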
@@ -0,0 +1,162 @@
+ # Belief State Transformer
+
+ # https://arxiv.org/abs/2410.23506
+ # https://www.youtube.com/watch?v=aqhbRtB2Fyg
+
+ import torch
+ from torch.autograd import Function
+ from torch.nn import Module, ModuleList
+ from torch import nn, cat, stack, arange, cartesian_prod
+ import torch.nn.functional as F
+
+ from x_transformers.x_transformers import (
+     Decoder,
+     TransformerWrapper
+ )
+
+ import einx
+ from einops import rearrange, repeat
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # wrappers
+
+ class BeliefStateWrapper(Module):
+     """
+     Figure 13. in https://arxiv.org/abs/2410.23506
+     """
+
+     def __init__(
+         self,
+         forward_decoder: TransformerWrapper,
+         backward_decoder: TransformerWrapper,
+         train_frac_forward_backward_pairs: float = 1.
+     ):
+         super().__init__()
+         assert forward_decoder.emb_dim == backward_decoder.emb_dim, 'forward and backward models must have the same embedding dimension'
+         assert forward_decoder.num_tokens == backward_decoder.num_tokens, 'forward and backward models must have the same number of tokens'
+
+         dim = forward_decoder.emb_dim
+         num_tokens = forward_decoder.num_tokens
+
+         # the suffix token
+
+         self.suffix_token = nn.Parameter(torch.zeros(dim))
+         nn.init.normal_(self.suffix_token, std = 0.02)
+
+         # the text prediction head, which, for each prefix and suffix combination, predicts both the next token (for the forward sequence) and the previous token (for the backward sequence)
+
+         self.text_head = nn.Sequential(
+             nn.Linear(dim * 2, dim),
+             nn.LeakyReLU(),
+             nn.Linear(dim, num_tokens * 2),
+         )
+
+         # the two decoders, one which is causal forward, the other causal backwards
+
+         self.forward_decoder = forward_decoder
+         self.backward_decoder = backward_decoder
+
+         # what fraction of forward backward pairs to train on
+         # for further memory efficiency
+
+         assert 0 < train_frac_forward_backward_pairs <= 1.
+         self.train_frac_fb_pairs = train_frac_forward_backward_pairs
+         self.needs_subsample_fb_pairs = train_frac_forward_backward_pairs < 1.
+
+     def forward(
+         self,
+         seq
+     ):
+         batch, seq_len, device = *seq.shape, seq.device
+
+         # forward autoregressive
+
+         forward_embeds = self.forward_decoder(seq, return_embeddings = True)
+
+         # backward autoregressive
+
+         backward_seq = seq.flip(1)
+
+         suffix_tokens = repeat(self.suffix_token, 'd -> b 1 d', b = batch)
+
+         backward_embeds = self.backward_decoder(
+             backward_seq,
+             prepend_embeds = suffix_tokens,
+             return_embeddings = True
+         )
+
+         backward_embeds = backward_embeds.flip(1)
+
+         # trick to reduce memory on backwards pass
+
+         orig_forward_embeds, forward_embeds = forward_embeds, forward_embeds.detach()
+         orig_backward_embeds, backward_embeds = backward_embeds, backward_embeds.detach()
+
+         forward_embeds.requires_grad_()
+         backward_embeds.requires_grad_()
+
+         # belief state objective
+
+         seq_arange = arange(seq_len, device = device)
+
+         fb_pairs = cartesian_prod(seq_arange, seq_arange)
+
+         # filter down to valid pairs, as in figure 11
+         # f - forward, b - backward, i - indices
+
+         fi, bi = fb_pairs.unbind(dim = -1)
+         valid_mask = (bi - fi) >= 2
+
+         fb_pairs = fb_pairs[valid_mask]
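+         # e.g. for seq_len = 4, the surviving (fi, bi) pairs are (0, 2), (0, 3) and (1, 3),
+         # so there is always at least one token strictly between prefix and suffix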
+
+         # maybe subsample fb pairs
+
+         if self.needs_subsample_fb_pairs:
+             num_pairs = fb_pairs.shape[0]
+
+             num_subsampled = max(int(num_pairs * self.train_frac_fb_pairs), 1)
+
+             rand_subsampled_indices = torch.randperm(num_pairs, device = device)[:num_subsampled]
+
+             fb_pairs = fb_pairs[rand_subsampled_indices]
+
+         # get labels for both
+
+         fi, bi = fb_pairs.unbind(dim = -1)
+
+         labels_fi, labels_bi = (fi + 1), bi
+
+         forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+         labels = stack((forward_labels, backward_labels), dim = -1)
+
+         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
+
+         fb_embeds = cat((
+             forward_embeds[:, fi],
+             backward_embeds[:, bi + 1]  # one position past bi, so the backward stream has not yet seen the token it must predict
+         ), dim = -1)
+
+         logits = self.text_head(fb_embeds)
+
+         # cross entropy loss
+
+         fb_loss = F.cross_entropy(
+             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
+             rearrange(labels, 'b n fb -> b (fb n)')
+         )
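+         # logits are rearranged from (batch, pairs, 2 * num_tokens) to (batch, num_tokens, 2 * pairs)
+         # and labels from (batch, pairs, 2) to (batch, 2 * pairs), so a single cross entropy
+         # scores the forward and backward predictions of every pair at once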
+
+         # backwards
+
+         fb_loss.backward()
+
+         orig_forward_embeds.backward(forward_embeds.grad)
+         orig_backward_embeds.backward(backward_embeds.grad)
+
+         return fb_loss
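The detach-and-reinject pattern at the end of `forward` is what the in-code comment calls a "trick to reduce memory on backwards pass": the pairwise text head is backpropagated against detached copies of the trunk embeddings, and the accumulated gradients are then pushed through each decoder trunk in a separate pass. A minimal sketch of the same pattern with toy modules (hypothetical, not part of the package):

```python
import torch
from torch import nn

# toy stand-ins for a decoder trunk and the pairwise text head
encoder = nn.Linear(8, 8)
head = nn.Linear(8, 1)

x = torch.randn(4, 8)
embeds = encoder(x)

# cut the graph: the loss below only backprops as far as `detached`
detached = embeds.detach()
detached.requires_grad_()

loss = head(detached).sum()
loss.backward()  # populates head's param grads and detached.grad; encoder untouched

# second pass: push the accumulated gradient through the trunk once
embeds.backward(detached.grad)

assert all(p.grad is not None for p in encoder.parameters())
```

Because `loss.backward()` stops at the detached leaf, the head's graph can be freed before the trunk's backward pass runs.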
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.1.2
+ Version: 2.1.5
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2435,4 +2435,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @inproceedings{Hu2024TheBS,
+     title   = {The Belief State Transformer},
+     author  = {Edward S. Hu and Kwangjun Ahn and Qinghua Liu and Haoran Xu and Manan Tomar and Ada Langford and Dinesh Jayaraman and Alex Lamb and John Langford},
+     year    = {2024},
+     url     = {https://api.semanticscholar.org/CorpusID:273707334}
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -1,6 +1,7 @@
  x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
  x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
+ x_transformers/belief_state.py,sha256=GfYDeDqmhldozECgFsJ9zhd6O5NMvdYA5OwueVs8SB4,4742
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -9,7 +10,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
  x_transformers/x_transformers.py,sha256=-80N4sqUr3sR51Ms4wCfc4jhxnPwf0ApNR4xfIsasfQ,110142
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-2.1.2.dist-info/METADATA,sha256=-LsNGhf7qKzttPNU7VOSVqigs61_Nuw4r0LBlZDT_Qo,87227
- x_transformers-2.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.1.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.1.2.dist-info/RECORD,,
+ x_transformers-2.1.5.dist-info/METADATA,sha256=-nQpm1eBGBXkLEXGiQK06NaIWa13CEnBujBNfTzvnJ8,87570
+ x_transformers-2.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.1.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.1.5.dist-info/RECORD,,