x-transformers 2.7.6__py3-none-any.whl → 2.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/gpt_vae.py +208 -0
- {x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/METADATA +13 -1
- {x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/RECORD +5 -4
- {x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/licenses/LICENSE +0 -0
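According to this diff, the only source change between the two versions is the new `x_transformers/gpt_vae.py` module shown below; the remaining differences are packaging metadata (METADATA, RECORD). After upgrading (for example with `pip install -U x-transformers`), the new class should be importable directly from that module, as this small sketch assumes:

```python
# assumes x-transformers 2.8.1 as shipped in this wheel,
# which adds the x_transformers/gpt_vae.py module
from x_transformers.gpt_vae import GPTVAE
```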
@@ -0,0 +1,208 @@
+from __future__ import annotations
+
+# applying the cvae + detr design from ACT (Zhao et al.) to GPT
+# for steering, diversity rlvr, map-elites in epo, and other possibilities
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+from einops.layers.torch import Rearrange
+from einops import rearrange, reduce, repeat
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# classes
+
+class GPTVAE(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        vae_kl_loss_weight = 1.,
+        latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
+        pad_id = -1,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.encoder = TransformerWrapper(
+            num_tokens = num_tokens,
+            max_seq_len = max_seq_len + 1,
+            return_only_embed = True,
+            average_pool_embed = True,
+            attn_layers = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            ),
+        )
+
+        self.to_latent_mean_log_variance = nn.Sequential(
+            nn.Linear(dim, dim_latent * 2),
+            Rearrange('b (two d) -> two b d', two = 2)
+        )
+
+        self.from_latent_to_prepend_token = nn.Sequential(
+            nn.Linear(dim_latent, dim),
+            Rearrange('b d -> b 1 d')
+        )
+
+        self.decoder = TransformerWrapper(
+            num_tokens = num_tokens,
+            max_seq_len = max_seq_len,
+            attn_layers = Decoder(
+                dim = dim,
+                depth = depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **dec_kwargs
+            ),
+        )
+
+        self.ar_wrapped_decoder = AutoregressiveWrapper(self.decoder, ignore_index = pad_id)
+
+        self.pad_id = pad_id
+
+        # loss weights - vae kl loss
+
+        self.vae_kl_loss_weight = vae_kl_loss_weight
+
+        self.latents_dropout = nn.Dropout(latents_dropout_prob)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        return_mean_log_var = False
+    ):
+        mask = seq != self.pad_id
+        pooled = self.encoder(seq, mask = mask)
+
+        latents_mean, latents_log_var = self.to_latent_mean_log_variance(pooled)
+        latents_std = (0.5 * latents_log_var).exp()
+
+        # reparam trick
+
+        latents = latents_mean + latents_std * torch.randn_like(latents_mean)
+
+        if not return_mean_log_var:
+            return latents
+
+        return latents, (latents_mean, latents_log_var)
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        seq_for_latents = None,
+        **generate_kwargs
+    ):
+        assert prompts.ndim in {1, 2}
+        batch = prompts.shape[0] if prompts.ndim == 2 else 1
+
+        # if seq_for_latents passed in, derive latents from it
+
+        if exists(seq_for_latents):
+            assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+            latents = self.encode_to_latents(seq_for_latents)
+
+        # prepend embeds
+
+        prepend_embeds = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+        # generated
+
+        generated = self.ar_wrapped_decoder.generate(
+            prompts,
+            seq_len,
+            prepend_embeds = prepend_embeds,
+            **generate_kwargs
+        )
+
+        return generated
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        latents, (latents_mean, latents_log_var) = self.encode_to_latents(seq, return_mean_log_var = True)
+
+        dropped_latents = ~self.latents_dropout(torch.ones((batch,), device = device)).bool()
+
+        prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+        ar_loss = self.ar_wrapped_decoder(
+            seq,
+            prepend_embeds = prepend_embeds,
+            seq_start_pos = dropped_latents.long() # sequence starts at 1 and does not attend to the first style latent
+        )
+
+        # vae kl loss
+
+        vae_kl_loss = (
+            latents_log_var.exp()
+            + latents_mean.square()
+            - latents_log_var
+            - 1.
+        ).sum(dim = -1).mean()
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            vae_kl_loss * self.vae_kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, vae_kl_loss)
+
+        return total_loss, losses
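The `vae_kl_loss` computed in `forward` is the standard closed-form KL divergence between the diagonal Gaussian posterior parameterized by `latents_mean` / `latents_log_var` and a standard normal prior, summed over latent dimensions and averaged over the batch; the code omits the conventional 1/2 factor, which is equivalent to folding it into `vae_kl_loss_weight`:

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^{2}) \,\big\|\, \mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_{d} \left( \sigma_{d}^{2} + \mu_{d}^{2} - \log \sigma_{d}^{2} - 1 \right)
```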
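For orientation, a minimal usage sketch of the new module; the hyperparameters and tensor shapes below are illustrative assumptions, not values prescribed by the release:

```python
import torch
from x_transformers.gpt_vae import GPTVAE

# hyperparameters here are illustrative only
model = GPTVAE(
    num_tokens = 256,
    dim = 512,
    depth = 6,        # decoder depth
    enc_depth = 3,    # encoder depth used to produce the latent
    max_seq_len = 1024
)

seq = torch.randint(0, 256, (2, 1024))

# training: returns the autoregressive loss plus the weighted KL term
loss = model(seq)
loss.backward()

# generation: a reference sequence can be encoded into latents that steer decoding
prompts = torch.randint(0, 256, (2, 16))
generated = model.generate(prompts, 64, seq_for_latents = seq)
```

The `seq_for_latents` path mirrors the ACT-style conditioning: the reference sequence is average-pooled by the encoder into a single latent, which is prepended as one extra token for the decoder to attend to. During training, `latents_dropout` drops the latent for a random subset of the batch and `seq_start_pos` shifts those rows past the prepended token, so the decoder is also trained without the latent conditioning.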
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.7.6
+Version: 2.8.1
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2540,4 +2540,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@misc{zhao2023learningfinegrainedbimanualmanipulation,
+    title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+    author = {Tony Z. Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
+    year = {2023},
+    eprint = {2304.13705},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.RO},
+    url = {https://arxiv.org/abs/2304.13705},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
|
@@ -5,6 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
 x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
+x_transformers/gpt_vae.py,sha256=Q2pzQ6iXRnP2Bfa6g-fs4US-JTouXB5-MfKw3sTwWmU,5561
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
@@ -12,7 +13,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
 x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
-x_transformers-2.
+x_transformers-2.8.1.dist-info/METADATA,sha256=_PnvoOSFJAgrpEfpNNljxdeYQ3BhDYJvVOp7yjaF-iM,94136
+x_transformers-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.8.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.8.1.dist-info/RECORD,,
{x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/WHEEL: File without changes
{x_transformers-2.7.6.dist-info → x_transformers-2.8.1.dist-info}/licenses/LICENSE: File without changes