x-transformers 2.7.5__py3-none-any.whl → 2.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,200 @@
+ from __future__ import annotations
+
+ # applying the cvae + detr design from ACT (Zhao et al.) to GPT
+ # for steering, diversity rlvr, map-elites in epo, and other possibilities
+
+ import torch
+ from torch import nn, Tensor, is_tensor, tensor
+ import torch.nn.functional as F
+ from torch.nn import Module, ModuleList
+
+ from x_transformers.x_transformers import (
+     Encoder,
+     Decoder,
+     TransformerWrapper
+ )
+
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+ from einops.layers.torch import Rearrange
+ from einops import rearrange, reduce, repeat
+
+ # helper functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ # classes
+
+ class GPTVAE(Module):
+     def __init__(
+         self,
+         *,
+         num_tokens,
+         dim,
+         depth,
+         enc_depth,
+         max_seq_len,
+         dim_latent = None,
+         attn_dim_head = 64,
+         heads = 8,
+         enc_kwargs: dict = dict(),
+         dec_kwargs: dict = dict(),
+         vae_kl_loss_weight = 1.,
+         latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
+         pad_id = -1,
+         **kwargs
+     ):
+         super().__init__()
+         dim_latent = default(dim_latent, dim)
+
+         self.encoder = TransformerWrapper(
+             num_tokens = num_tokens,
+             max_seq_len = max_seq_len + 1,
+             return_only_embed = True,
+             average_pool_embed = True,
+             attn_layers = Encoder(
+                 dim = dim,
+                 depth = enc_depth,
+                 attn_dim_head = attn_dim_head,
+                 heads = heads,
+                 **kwargs,
+                 **enc_kwargs
+             ),
+         )
+
+         self.to_latent_mean_log_variance = nn.Sequential(
+             nn.Linear(dim, dim_latent * 2),
+             Rearrange('b (two d) -> two b 1 d', two = 2)
+         )
+
+         self.from_latent_to_prepend_token = nn.Linear(dim_latent, dim)
+
+         self.decoder = TransformerWrapper(
+             num_tokens = num_tokens,
+             max_seq_len = max_seq_len,
+             attn_layers = Decoder(
+                 dim = dim,
+                 depth = depth,
+                 attn_dim_head = attn_dim_head,
+                 heads = heads,
+                 **kwargs,
+                 **dec_kwargs
+             ),
+         )
+
+         self.ar_wrapped_decoder = AutoregressiveWrapper(self.decoder, ignore_index = pad_id)
+
+         self.pad_id = pad_id
+
+         # loss weights - vae kl loss
+
+         self.vae_kl_loss_weight = vae_kl_loss_weight
+
+         self.latents_dropout = nn.Dropout(latents_dropout_prob)
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def encode_to_latents(
+         self,
+         seq,
+         return_mean_log_var = False
+     ):
+         mask = seq != self.pad_id
+         pooled = self.encoder(seq, mask = mask)
+
+         latents_mean, latents_log_var = self.to_latent_mean_log_variance(pooled)
+         latents_std = (0.5 * latents_log_var).exp()
+
+         # reparam trick
+
+         latents = latents_mean + latents_std * torch.randn_like(latents_mean)
+
+         if not return_mean_log_var:
+             return latents
+
+         return latents, (latents_mean, latents_log_var)
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompts,
+         seq_len,
+         latents = None,
+         **generate_kwargs
+     ):
+         assert prompts.ndim in {1, 2}
+         batch = prompts.shape[0] if prompts.ndim == 2 else 1
+
+         # prepend embeds
+
+         prepend_embeds = None
+         if exists(latents):
+             if not is_tensor(latents):
+                 latents = tensor(latents, device = self.device)
+
+             if latents.ndim == 1: # repeat latents
+                 latents = repeat(latents, 'd -> b d', b = batch)
+
+             prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+         if exists(prepend_embeds):
+             prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
+
+         # generated
+
+         generated = self.ar_wrapped_decoder.generate(
+             prompts,
+             seq_len,
+             prepend_embeds = prepend_embeds,
+             **generate_kwargs
+         )
+
+         return generated
+
+     def forward(
+         self,
+         seq,
+         return_all_losses = False
+     ):
+         batch, device = seq.shape[0], seq.device
+
+         latents, (latents_mean, latents_log_var) = self.encode_to_latents(seq, return_mean_log_var = True)
+
+         dropped_latents = ~self.latents_dropout(torch.ones((batch,), device = device)).bool()
+
+         prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+         ar_loss = self.ar_wrapped_decoder(
+             seq,
+             prepend_embeds = prepend_embeds,
+             seq_start_pos = dropped_latents.long() # sequence starts at 1 and does not attend to the first style latent
+         )
+
+         # vae kl loss
+
+         vae_kl_loss = (
+             latents_log_var.exp()
+             + latents_mean.square()
+             - latents_log_var
+             - 1.
+         ).sum(dim = -1).mean()
+
+         # return losses
+
+         total_loss = (
+             ar_loss +
+             vae_kl_loss * self.vae_kl_loss_weight
+         )
+
+         if not return_all_losses:
+             return total_loss
+
+         losses = (ar_loss, vae_kl_loss)
+
+         return total_loss, losses
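The hunk above adds `x_transformers/gpt_vae.py`: an encoder Transformer pools the sequence into a Gaussian latent via the reparameterization trick, and the decoder is conditioned on that latent through a single prepended token embedding, trained with the autoregressive loss plus a weighted VAE KL term. A minimal usage sketch based only on the signatures shown in this diff; the hyperparameters, vocabulary size, and dummy data below are illustrative assumptions, not values from the package:

```python
import torch
from x_transformers.gpt_vae import GPTVAE

# illustrative sizes only
model = GPTVAE(
    num_tokens = 256,
    dim = 512,
    depth = 6,         # decoder depth
    enc_depth = 3,     # encoder depth
    max_seq_len = 1024
)

seq = torch.randint(0, 256, (2, 1024))

# training: autoregressive loss + vae_kl_loss_weight * KL divergence
loss = model(seq)
loss.backward()

# encode sequences into latents, then steer generation with them
latents = model.encode_to_latents(seq)   # (batch, 1, dim_latent)
latents = latents[:, 0]                  # generate() expects (dim_latent,) or (batch, dim_latent)

prompts = seq[:, :1]
generated = model.generate(prompts, 256, latents = latents)
```

Note that `latents_dropout_prob` controls how often the latent conditioning is dropped entirely during training (via `seq_start_pos`), so the decoder also learns to generate without a latent.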
@@ -2469,12 +2469,15 @@ class AttentionLayers(Module):
      ):
          # pairs up the attention intermediates with each attention module and does qk clip proposed by kimi team

-         for (_, layer, _), layer_type, attn_inter in zip(self.layers, self.layer_types, intermediates.attn_intermediates):
+         layer_and_layer_types = (self.layers, self.layer_types)

-             if layer_type not in ('a', 'c'):
-                 continue
+         attn_layers = [layer for (_, layer, _), layer_type in zip(self.layers, self.layer_types) if layer_type in ('a', 'c')]
+         attn_intermeds = intermediates.attn_intermediates
+
+         assert len(attn_layers) == len(attn_intermeds)

-             layer.qk_clip_(attn_inter, tau = tau)
+         for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
+             attn_layer.qk_clip_(attn_inter, tau = tau)

      def forward(
          self,
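The refactor above first filters the attention (and cross-attention) blocks, then pairs each with its recorded intermediates before calling `qk_clip_`, the QK-clip technique credited to the Kimi team in the comment: when a head's maximum attention logit drifts above a threshold `tau`, its query and key projections are scaled back down. A schematic sketch of that general idea with a hypothetical standalone signature (the library's actual `qk_clip_` takes the attention intermediates and `tau`, as shown in the hunk):

```python
import torch

@torch.no_grad()
def qk_clip_(w_q: torch.Tensor, w_k: torch.Tensor, max_logit: torch.Tensor, tau: float = 100.):
    # w_q, w_k: query / key projection weights for a single head (modified in place)
    # max_logit: scalar tensor, the largest pre-softmax attention logit observed for that head

    if max_logit <= tau:
        return

    # scale queries and keys each by sqrt(tau / max_logit), so their dot product shrinks by tau / max_logit
    gamma = (tau / max_logit).sqrt()

    w_q.mul_(gamma)
    w_k.mul_(gamma)
```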
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.7.5
+ Version: 2.8.0
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2540,4 +2540,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
  }
  ```

+ ```bibtex
+ @misc{zhao2023learningfinegrainedbimanualmanipulation,
+     title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+     author = {Tony Z. Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
+     year = {2023},
+     eprint = {2304.13705},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.RO},
+     url = {https://arxiv.org/abs/2304.13705},
+ }
+ ```
+
  *solve intelligence... then use that to solve everything else.* - Demis Hassabis
@@ -5,14 +5,15 @@ x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTN
  x_transformers/continuous.py,sha256=WwpQCjyVY4PtuEAOFY68zqgklbF9I7AL5w6874YlDe8,13249
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
+ x_transformers/gpt_vae.py,sha256=yqL1K2yJ6RSP_MC6XSHI3hjiUnaptddg6CUnbEX4Bsk,5281
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
  x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
- x_transformers/x_transformers.py,sha256=xaGBkYCy6CqL0q9icWmL_WzCeU6ZztEYEkMtN71L2z4,124576
+ x_transformers/x_transformers.py,sha256=odnCZAKZKrQLXmpaWhiPVB5elGjt8kerDbO3-yeC-60,124764
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
- x_transformers-2.7.5.dist-info/METADATA,sha256=m6f4PIgJFKKWlsGAydi_Bg5-7-0IRlor0pRY_zBh5s8,93739
- x_transformers-2.7.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- x_transformers-2.7.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-2.7.5.dist-info/RECORD,,
+ x_transformers-2.8.0.dist-info/METADATA,sha256=jPo0ZPhD1d_aocaDqFYWXA7EXPAcxWeUYNDzKpY1yi8,94136
+ x_transformers-2.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ x_transformers-2.8.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.8.0.dist-info/RECORD,,