x-transformers 2.7.6__py3-none-any.whl → 2.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registry.
@@ -0,0 +1,208 @@
+ from __future__ import annotations
+
+ # applying the cvae + detr design from ACT (Zhao et al.) to GPT
+ # for steering, diversity rlvr, map-elites in epo, and other possibilities
+
+ import torch
+ from torch import nn, Tensor, is_tensor, tensor
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+
+ from x_transformers.x_transformers import (
+     Encoder,
+     Decoder,
+     TransformerWrapper
+ )
+
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+ from einops.layers.torch import Rearrange
+ from einops import rearrange, reduce, repeat
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # classes
+
+ class GPTVAE(Module):
+     def __init__(
+         self,
+         *,
+         num_tokens,
+         dim,
+         depth,
+         enc_depth,
+         max_seq_len,
+         dim_latent = None,
+         attn_dim_head = 64,
+         heads = 8,
+         enc_kwargs: dict = dict(),
+         dec_kwargs: dict = dict(),
+         vae_kl_loss_weight = 1.,
+         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
+         pad_id = -1,
+         **kwargs
+     ):
+         super().__init__()
+         dim_latent = default(dim_latent, dim)
+
+         self.encoder = TransformerWrapper(
+             num_tokens = num_tokens,
+             max_seq_len = max_seq_len + 1,
+             return_only_embed = True,
+             average_pool_embed = True,
+             attn_layers = Encoder(
+                 dim = dim,
+                 depth = enc_depth,
+                 attn_dim_head = attn_dim_head,
+                 heads = heads,
+                 **kwargs,
+                 **enc_kwargs
+             ),
+         )
+
+         self.to_latent_mean_log_variance = nn.Sequential(
+             nn.Linear(dim, dim_latent * 2),
+             Rearrange('b (two d) -> two b d', two = 2)
+         )
+
+         self.from_latent_to_prepend_token = nn.Sequential(
+             nn.Linear(dim_latent, dim),
+             Rearrange('b d -> b 1 d')
+         )
+
+         self.decoder = TransformerWrapper(
+             num_tokens = num_tokens,
+             max_seq_len = max_seq_len,
+             attn_layers = Decoder(
+                 dim = dim,
+                 depth = depth,
+                 attn_dim_head = attn_dim_head,
+                 heads = heads,
+                 **kwargs,
+                 **dec_kwargs
+             ),
+         )
+
+         self.ar_wrapped_decoder = AutoregressiveWrapper(self.decoder, ignore_index = pad_id)
+
+         self.pad_id = pad_id
+
+         # loss weights - vae kl loss
+
+         self.vae_kl_loss_weight = vae_kl_loss_weight
+
+         self.latents_dropout = nn.Dropout(latents_dropout_prob)
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def encode_to_latents(
+         self,
+         seq,
+         return_mean_log_var = False
+     ):
+         mask = seq != self.pad_id
+         pooled = self.encoder(seq, mask = mask)
+
+         latents_mean, latents_log_var = self.to_latent_mean_log_variance(pooled)
+         latents_std = (0.5 * latents_log_var).exp()
+
+         # reparam trick
+
+         latents = latents_mean + latents_std * torch.randn_like(latents_mean)
+
+         if not return_mean_log_var:
+             return latents
+
+         return latents, (latents_mean, latents_log_var)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompts,
+         seq_len,
+         latents = None,
+         seq_for_latents = None,
+         **generate_kwargs
+     ):
+         assert prompts.ndim in {1, 2}
+         batch = prompts.shape[0] if prompts.ndim == 2 else 1
+
+         # if seq_for_latents passed in, derive latents from it
+
+         if exists(seq_for_latents):
+             assert not exists(latents), 'latents should not be passed in if given the seq from which to derive them'
+
+             latents = self.encode_to_latents(seq_for_latents)
+
+         # prepend embeds
+
+         prepend_embeds = None
+         if exists(latents):
+             if not is_tensor(latents):
+                 latents = tensor(latents, device = self.device)
+
+             if latents.ndim == 1: # repeat latents
+                 latents = repeat(latents, 'd -> b d', b = batch)
+
+             prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+         # generated
+
+         generated = self.ar_wrapped_decoder.generate(
+             prompts,
+             seq_len,
+             prepend_embeds = prepend_embeds,
+             **generate_kwargs
+         )
+
+         return generated
+
+     def forward(
+         self,
+         seq,
+         return_all_losses = False
+     ):
+         batch, device = seq.shape[0], seq.device
+
+         latents, (latents_mean, latents_log_var) = self.encode_to_latents(seq, return_mean_log_var = True)
+
+         dropped_latents = ~self.latents_dropout(torch.ones((batch,), device = device)).bool()
+
+         prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+         ar_loss = self.ar_wrapped_decoder(
+             seq,
+             prepend_embeds = prepend_embeds,
+             seq_start_pos = dropped_latents.long() # sequence starts at 1 and does not attend to the first style latent
+         )
+
+         # vae kl loss
+
+         vae_kl_loss = (
+             latents_log_var.exp()
+             + latents_mean.square()
+             - latents_log_var
+             - 1.
+         ).sum(dim = -1).mean()
+
+         # return losses
+
+         total_loss = (
+             ar_loss +
+             vae_kl_loss * self.vae_kl_loss_weight
+         )
+
+         if not return_all_losses:
+             return total_loss
+
+         losses = (ar_loss, vae_kl_loss)
+
+         return total_loss, losses
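
The hunk above adds the new `x_transformers/gpt_vae.py` module (see the RECORD entries at the bottom of this diff), which wraps an average-pooled encoder producing a sequence-level latent around an autoregressive decoder. For orientation, here is a minimal usage sketch inferred from the constructor and methods shown in the diff; the hyperparameter values, tensor shapes, and the sampling call are illustrative assumptions, not documented API guarantees.

```python
# minimal sketch, assuming x-transformers 2.8.1 with the gpt_vae module shown above;
# all hyperparameters and shapes below are illustrative

import torch
from x_transformers.gpt_vae import GPTVAE

model = GPTVAE(
    num_tokens = 256,     # vocabulary size
    dim = 512,            # model dimension
    depth = 6,            # decoder depth
    enc_depth = 3,        # encoder depth (produces the pooled latent)
    max_seq_len = 1024
)

seq = torch.randint(0, 256, (2, 768))

# training: total loss = autoregressive loss + weighted VAE KL loss
loss = model(seq)
loss.backward()

# inspect the individual loss terms
total_loss, (ar_loss, vae_kl_loss) = model(seq, return_all_losses = True)

# generation, optionally steered by latents encoded from a reference sequence
prompts = torch.randint(0, 256, (2, 1))
generated = model.generate(prompts, 256, seq_for_latents = seq)  # sampled token ids
```

Note that the KL term in `forward` is the closed-form KL between the diagonal Gaussian posterior and a unit Gaussian prior, written without the usual 1/2 factor; `vae_kl_loss_weight` can absorb that constant.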
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.7.6
+ Version: 2.8.1
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2540,4 +2540,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```
 
+ ```bibtex
+ @misc{zhao2023learningfinegrainedbimanualmanipulation,
+     title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+     author = {Tony Z. Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
+     year = {2023},
+     eprint = {2304.13705},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.RO},
+     url = {https://arxiv.org/abs/2304.13705},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -5,6 +5,7 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
  x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
+ x_transformers/gpt_vae.py,sha256=Q2pzQ6iXRnP2Bfa6g-fs4US-JTouXB5-MfKw3sTwWmU,5561
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
@@ -12,7 +13,7 @@ x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,
  x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
- x_transformers-2.7.6.dist-info/METADATA,sha256=n-AKJXX2Ko3XlehMOv5MojPrFaHdRi4lRkvcGAFOXR4,93739
- x_transformers-2.7.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.7.6.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.7.6.dist-info/RECORD,,
+ x_transformers-2.8.1.dist-info/METADATA,sha256=_PnvoOSFJAgrpEfpNNljxdeYQ3BhDYJvVOp7yjaF-iM,94136
+ x_transformers-2.8.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.8.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.8.1.dist-info/RECORD,,