x-transformers 2.7.6.tar.gz → 2.8.0.tar.gz

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (65)
  1. {x_transformers-2.7.6 → x_transformers-2.8.0}/PKG-INFO +13 -1
  2. {x_transformers-2.7.6 → x_transformers-2.8.0}/README.md +12 -0
  3. {x_transformers-2.7.6 → x_transformers-2.8.0}/pyproject.toml +1 -1
  4. x_transformers-2.8.0/train_gpt_vae.py +131 -0
  5. x_transformers-2.8.0/x_transformers/gpt_vae.py +200 -0
  6. {x_transformers-2.7.6 → x_transformers-2.8.0}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.7.6 → x_transformers-2.8.0}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.7.6 → x_transformers-2.8.0}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.7.6 → x_transformers-2.8.0}/.gitignore +0 -0
  10. {x_transformers-2.7.6 → x_transformers-2.8.0}/LICENSE +0 -0
  11. {x_transformers-2.7.6 → x_transformers-2.8.0}/data/README.md +0 -0
  12. {x_transformers-2.7.6 → x_transformers-2.8.0}/data/enwik8.gz +0 -0
  13. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/all-attention.png +0 -0
  14. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/attention-on-attention.png +0 -0
  15. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/cosine-sim-attention.png +0 -0
  16. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/deepnorm.png +0 -0
  17. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/dynamic-pos-bias-linear.png +0 -0
  18. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/dynamic-pos-bias-log.png +0 -0
  19. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  20. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/dynamic-pos-bias.png +0 -0
  21. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/enhanced-recurrence.png +0 -0
  22. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/fcm.png +0 -0
  23. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/ffglu.png +0 -0
  24. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/flash-attention.png +0 -0
  25. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/gate_values.png +0 -0
  26. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/gating.png +0 -0
  27. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/length-extrapolation-scale.png +0 -0
  28. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/macaron-1.png +0 -0
  29. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/macaron-2.png +0 -0
  30. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/memory-transformer.png +0 -0
  31. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/normformer.png +0 -0
  32. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/pia.png +0 -0
  33. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/qknorm-analysis.png +0 -0
  34. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/resi_dual.png +0 -0
  35. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/residual_attn.png +0 -0
  36. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/rezero.png +0 -0
  37. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/rotary.png +0 -0
  38. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/sandwich-2.png +0 -0
  39. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/sandwich.png +0 -0
  40. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/sandwich_norm.png +0 -0
  41. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/scalenorm.png +0 -0
  42. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/talking-heads.png +0 -0
  43. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/topk-attention.png +0 -0
  44. {x_transformers-2.7.6 → x_transformers-2.8.0}/images/xval.png +0 -0
  45. {x_transformers-2.7.6 → x_transformers-2.8.0}/tests/test_x_transformers.py +0 -0
  46. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_belief_state.py +0 -0
  47. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_copy.py +0 -0
  48. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_entropy_tokenizer.py +0 -0
  49. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_enwik8.py +0 -0
  50. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.7.6 → x_transformers-2.8.0}/train_parity.py +0 -0
  52. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/multi_input.py +0 -0
  60. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/neo_mlp.py +0 -0
  61. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  62. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/up_wrapper.py +0 -0
  63. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/x_transformers.py +0 -0
  64. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  65. {x_transformers-2.7.6 → x_transformers-2.8.0}/x_transformers/xval.py +0 -0
{x_transformers-2.7.6 → x_transformers-2.8.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.7.6
+Version: 2.8.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2540,4 +2540,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{zhao2023learningfinegrainedbimanualmanipulation,
+    title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+    author = {Tony Z. Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
+    year = {2023},
+    eprint = {2304.13705},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.RO},
+    url = {https://arxiv.org/abs/2304.13705},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.7.6 → x_transformers-2.8.0}/README.md

@@ -2492,4 +2492,16 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```

+```bibtex
+@misc{zhao2023learningfinegrainedbimanualmanipulation,
+    title = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
+    author = {Tony Z. Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
+    year = {2023},
+    eprint = {2304.13705},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.RO},
+    url = {https://arxiv.org/abs/2304.13705},
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
{x_transformers-2.7.6 → x_transformers-2.8.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.7.6"
+version = "2.8.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
x_transformers-2.8.0/train_gpt_vae.py

@@ -0,0 +1,131 @@
+
+from x_transformers.gpt_vae import GPTVAE
+
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+import torch.optim as optim
+from torch import tensor
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SEQ_LEN = 512
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = GPTVAE(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    dim = 512,
+    depth = 6,
+    heads = 8,
+    rotary_pos_emb = True,
+    enc_depth = 3,
+    vae_kl_loss_weight = 1.,
+    dim_latent = 1 # compress to 1 as an example
+).cuda()
+
+latents = tensor([1.]).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss, (ar_loss, vae_kl_loss) = model(next(train_loader), return_all_losses = True)
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {ar_loss.item():.4f}\t| kl loss: {vae_kl_loss.item():.4f}')
+
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss, (ar_loss, _) = model(next(val_loader), return_all_losses = True)
+            print(f'validation loss: {ar_loss.item():.4f}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'{prime} \n\n {"*" * 100}')
+
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True,
+            latents = latents
+        )
+
+        output_str = decode_tokens(sample)
+
+        print(f'\n\nlatent {latents.tolist()} - ', output_str)
+
+        sample_other_direction = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True,
+            latents = -latents
+        )
+
+        output_str = decode_tokens(sample_other_direction)
+        print(f'\n\nlatent {(-latents).tolist()} - ', output_str)
x_transformers-2.8.0/x_transformers/gpt_vae.py

@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+# applying the cvae + detr design from ACT (Zhao et al.) to GPT
+# for steering, diversity rlvr, map-elites in epo, and other possibilities
+
+import torch
+from torch import nn, Tensor, is_tensor, tensor
+import torch.nn.functional as F
+from torch.nn import Module, ModuleList
+
+from x_transformers.x_transformers import (
+    Encoder,
+    Decoder,
+    TransformerWrapper
+)
+
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+from einops.layers.torch import Rearrange
+from einops import rearrange, reduce, repeat
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+# classes
+
+class GPTVAE(Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        depth,
+        enc_depth,
+        max_seq_len,
+        dim_latent = None,
+        attn_dim_head = 64,
+        heads = 8,
+        enc_kwargs: dict = dict(),
+        dec_kwargs: dict = dict(),
+        vae_kl_loss_weight = 1.,
+        latents_dropout_prob = 0.5, # what percentage of the time to dropout the latents completely
+        pad_id = -1,
+        **kwargs
+    ):
+        super().__init__()
+        dim_latent = default(dim_latent, dim)
+
+        self.encoder = TransformerWrapper(
+            num_tokens = num_tokens,
+            max_seq_len = max_seq_len + 1,
+            return_only_embed = True,
+            average_pool_embed = True,
+            attn_layers = Encoder(
+                dim = dim,
+                depth = enc_depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **enc_kwargs
+            ),
+        )
+
+        self.to_latent_mean_log_variance = nn.Sequential(
+            nn.Linear(dim, dim_latent * 2),
+            Rearrange('b (two d) -> two b 1 d', two = 2)
+        )
+
+        self.from_latent_to_prepend_token = nn.Linear(dim_latent, dim)
+
+        self.decoder = TransformerWrapper(
+            num_tokens = num_tokens,
+            max_seq_len = max_seq_len,
+            attn_layers = Decoder(
+                dim = dim,
+                depth = depth,
+                attn_dim_head = attn_dim_head,
+                heads = heads,
+                **kwargs,
+                **dec_kwargs
+            ),
+        )
+
+        self.ar_wrapped_decoder = AutoregressiveWrapper(self.decoder, ignore_index = pad_id)
+
+        self.pad_id = pad_id
+
+        # loss weights - vae kl loss
+
+        self.vae_kl_loss_weight = vae_kl_loss_weight
+
+        self.latents_dropout = nn.Dropout(latents_dropout_prob)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def encode_to_latents(
+        self,
+        seq,
+        return_mean_log_var = False
+    ):
+        mask = seq != self.pad_id
+        pooled = self.encoder(seq, mask = mask)
+
+        latents_mean, latents_log_var = self.to_latent_mean_log_variance(pooled)
+        latents_std = (0.5 * latents_log_var).exp()
+
+        # reparam trick
+
+        latents = latents_mean + latents_std * torch.randn_like(latents_mean)
+
+        if not return_mean_log_var:
+            return latents
+
+        return latents, (latents_mean, latents_log_var)
+
+    @torch.no_grad()
+    def generate(
+        self,
+        prompts,
+        seq_len,
+        latents = None,
+        **generate_kwargs
+    ):
+        assert prompts.ndim in {1, 2}
+        batch = prompts.shape[0] if prompts.ndim == 2 else 1
+
+        # prepend embeds
+
+        prepend_embeds = None
+        if exists(latents):
+            if not is_tensor(latents):
+                latents = tensor(latents, device = self.device)
+
+            if latents.ndim == 1: # repeat latents
+                latents = repeat(latents, 'd -> b d', b = batch)
+
+            prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+        if exists(prepend_embeds):
+            prepend_embeds = rearrange(prepend_embeds, 'b d -> b 1 d')
+
+        # generated
+
+        generated = self.ar_wrapped_decoder.generate(
+            prompts,
+            seq_len,
+            prepend_embeds = prepend_embeds,
+            **generate_kwargs
+        )
+
+        return generated
+
+    def forward(
+        self,
+        seq,
+        return_all_losses = False
+    ):
+        batch, device = seq.shape[0], seq.device
+
+        latents, (latents_mean, latents_log_var) = self.encode_to_latents(seq, return_mean_log_var = True)
+
+        dropped_latents = ~self.latents_dropout(torch.ones((batch,), device = device)).bool()
+
+        prepend_embeds = self.from_latent_to_prepend_token(latents)
+
+        ar_loss = self.ar_wrapped_decoder(
+            seq,
+            prepend_embeds = prepend_embeds,
+            seq_start_pos = dropped_latents.long() # sequence starts at 1 and does not attend to the first style latent
+        )
+
+        # vae kl loss
+
+        vae_kl_loss = (
+            latents_log_var.exp()
+            + latents_mean.square()
+            - latents_log_var
+            - 1.
+        ).sum(dim = -1).mean()
+
+        # return losses
+
+        total_loss = (
+            ar_loss +
+            vae_kl_loss * self.vae_kl_loss_weight
+        )
+
+        if not return_all_losses:
+            return total_loss
+
+        losses = (ar_loss, vae_kl_loss)
+
+        return total_loss, losses
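As a reference for the loss term above: `encode_to_latents` samples the latent with the usual reparameterization trick, and `vae_kl_loss` in `forward` is the standard closed-form KL divergence between the encoder's diagonal Gaussian and a unit Gaussian prior, up to the conventional factor of 1/2 (here effectively absorbed into `vae_kl_loss_weight`):

```latex
z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)

\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_{j} \left( e^{\log \sigma_j^2} + \mu_j^2 - \log \sigma_j^2 - 1 \right)
```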
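Below is a minimal usage sketch of the new `GPTVAE`, mirroring the calls made in `train_gpt_vae.py` above; the small hyperparameters and the random toy batch are illustrative assumptions, not values from the release:

```python
import torch
from x_transformers.gpt_vae import GPTVAE

# small, CPU-friendly configuration (assumed values for illustration)
model = GPTVAE(
    num_tokens = 256,
    max_seq_len = 256,
    dim = 128,
    depth = 2,
    heads = 4,
    enc_depth = 1,
    dim_latent = 1  # a 1-dimensional steering latent, as in train_gpt_vae.py
)

# toy batch of byte-level token ids
seq = torch.randint(0, 256, (2, 256))

# training objective: autoregressive loss + weighted VAE KL loss
loss, (ar_loss, vae_kl_loss) = model(seq, return_all_losses = True)
loss.backward()

# inference: steer generation through the prepended latent token
prompt = seq[0, :32]

generated = model.generate(
    prompts = prompt,
    seq_len = 64,
    latents = torch.tensor([1.])
)
```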