x-transformers 2.1.15__tar.gz → 2.1.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.15 → x_transformers-2.1.17}/PKG-INFO +1 -1
  2. {x_transformers-2.1.15 → x_transformers-2.1.17}/pyproject.toml +1 -1
  3. {x_transformers-2.1.15 → x_transformers-2.1.17}/tests/test_x_transformers.py +1 -2
  4. x_transformers-2.1.17/train_belief_state.py +133 -0
  5. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/belief_state_wrapper.py +65 -26
  6. {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.15 → x_transformers-2.1.17}/.gitignore +0 -0
  10. {x_transformers-2.1.15 → x_transformers-2.1.17}/LICENSE +0 -0
  11. {x_transformers-2.1.15 → x_transformers-2.1.17}/README.md +0 -0
  12. {x_transformers-2.1.15 → x_transformers-2.1.17}/data/README.md +0 -0
  13. {x_transformers-2.1.15 → x_transformers-2.1.17}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/fcm.png +0 -0
  24. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/gating.png +0 -0
  28. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/normformer.png +0 -0
  33. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/pia.png +0 -0
  34. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/rezero.png +0 -0
  38. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/rotary.png +0 -0
  39. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.15 → x_transformers-2.1.17}/images/xval.png +0 -0
  46. {x_transformers-2.1.15 → x_transformers-2.1.17}/train_copy.py +0 -0
  47. {x_transformers-2.1.15 → x_transformers-2.1.17}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.15 → x_transformers-2.1.17}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.15 → x_transformers-2.1.17}/train_parity.py +0 -0
  50. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: x-transformers
- Version: 2.1.15
+ Version: 2.1.17
  Summary: X-Transformers
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "x-transformers"
- version = "2.1.15"
+ version = "2.1.17"
  description = "X-Transformers"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -732,8 +732,7 @@ def test_belief_state_wrapper(

  seq = torch.randint(0, 20000, (2, 16))

- loss = model(seq, backward = False)
- loss.backward()
+ loss = model(seq) # backwards happen automatically

  suffix = None
  if goal_suffix:
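The test change above captures the main API shift in this release: BeliefStateWrapper.forward now runs its own backward pass, so the old backward = False flag and the explicit loss.backward() call are gone. A minimal training-step sketch under that assumption (the decoder sizes and learning rate below are illustrative, not taken from the test):

    import torch
    from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

    # small illustrative model; the test uses num_tokens = 20000 and sequences of length 16
    model = BeliefStateWrapper(
        forward_decoder = TransformerWrapper(
            num_tokens = 20000,
            max_seq_len = 16,
            attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
        )
    )

    optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

    seq = torch.randint(0, 20000, (2, 16))

    loss = model(seq)       # gradients are already populated by the wrapper's internal backward()
    optimizer.step()        # the caller only steps and zeroes the optimizer
    optimizer.zero_grad()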
train_belief_state.py (new file)
@@ -0,0 +1,133 @@
+ from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+ import random
+ import tqdm
+ import gzip
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+
+ # constants
+
+ NUM_BATCHES = int(1e5)
+ BATCH_SIZE = 2
+ GRADIENT_ACCUMULATE_EVERY = 8
+ LEARNING_RATE = 1e-4
+ VALIDATE_EVERY = 100
+ GENERATE_EVERY = 500
+ GENERATE_LENGTH = 256
+ SEQ_LEN = 256
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # instantiate GPT-like decoder model for forward and backwards
+
+ forward_model = TransformerWrapper(
+     num_tokens = 256,
+     max_seq_len = SEQ_LEN,
+     attn_layers = Decoder(
+         dim = 512,
+         depth = 6,
+         heads = 8,
+         rotary_pos_emb = True
+     )
+ )
+
+ backward_model = TransformerWrapper(
+     num_tokens = 256,
+     max_seq_len = SEQ_LEN,
+     attn_layers = Decoder(
+         dim = 512,
+         depth = 4, # do a smaller backwards
+         heads = 8,
+         rotary_pos_emb = True
+     )
+ )
+
+ model = BeliefStateWrapper(
+     forward_decoder = forward_model,
+     backward_decoder = backward_model
+ )
+
+ model.cuda()
+
+ # prepare enwik8 data
+
+ with gzip.open('./data/enwik8.gz') as file:
+     data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+     train_x, valid_x = np.split(data, [int(90e6)])
+     data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+ class TextSamplerDataset(Dataset):
+     def __init__(self, data, seq_len):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+
+     def __getitem__(self, index):
+         rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+         return full_seq.cuda()
+
+     def __len__(self):
+         return self.data.size(0) // self.seq_len
+
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+ train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+ # optimizer
+
+ optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+ # training
+
+ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
+     model.train()
+
+     for __ in range(GRADIENT_ACCUMULATE_EVERY):
+         loss = model(
+             next(train_loader),
+             loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
+         )
+
+     print(f'training loss: {loss.item()}')
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+     optim.step()
+     optim.zero_grad()
+
+     if i % VALIDATE_EVERY == 0:
+         model.eval()
+         with torch.no_grad():
+             loss = model(next(val_loader), return_loss_only = True)
+             print(f'validation loss: {loss.item()}')
+
+     if i % GENERATE_EVERY == 0:
+         model.eval()
+         inp = random.choice(val_dataset)[:-1]
+         prime = decode_tokens(inp)
+         print(f'%s \n\n %s', (prime, '*' * 100))
+
+         sample = model.generate_with_suffix_cond(
+             prompts = inp,
+             seq_len = GENERATE_LENGTH,
+             cache_kv = True
+         )
+
+         output_str = decode_tokens(sample)
+         print(output_str)
x_transformers/belief_state_wrapper.py
@@ -23,7 +23,8 @@ from x_transformers.x_transformers import (
  )

  import einx
- from einops import rearrange, repeat
+ from einops import rearrange, repeat, pack, unpack
+ from einops.layers.torch import Rearrange

  # helper functions

@@ -55,7 +56,9 @@ class BeliefStateWrapper(Module):
  backward_decoder: TransformerWrapper | None = None,
  train_frac_forward_backward_pairs: float = 1.,
  text_head: Module | None = None,
- backward_ar_loss_weight: float = 1. # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
+ backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
+ pred_terminal = False,
+ pred_terminal_loss_weight: float = 1.
  ):
  super().__init__()
  backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
@@ -66,6 +69,8 @@ class BeliefStateWrapper(Module):
  dim = forward_decoder.emb_dim
  num_tokens = forward_decoder.num_tokens

+ self.to_forward_logits = nn.Linear(dim, num_tokens, bias = False)
+
  # the suffix token

  self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -82,6 +87,17 @@ class BeliefStateWrapper(Module):

  self.text_head = text_head

+ # predicting terminal state (when suffix and prefix predict the same token)
+
+ self.to_terminal_logit = nn.Sequential(
+     nn.Linear(dim * 2, dim),
+     nn.LeakyReLU(),
+     nn.Linear(dim, 1),
+     Rearrange('... 1 -> ...')
+ ) if pred_terminal else None
+
+ self.pred_terminal_loss_weight = pred_terminal_loss_weight
+
  # the two decoders, one which is causal forward, the other causal backwards

  self.forward_decoder = forward_decoder
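The constructor gains pred_terminal and pred_terminal_loss_weight, wiring in the two-layer terminal head above; when enabled, a binary cross-entropy term for the terminal state (the source comment's "suffix and prefix predict the same token") is added to the loss in forward (see the last hunk). A hedged construction sketch, reusing the enwik8-sized forward decoder from train_belief_state.py and an arbitrary loss weight:

    from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

    model = BeliefStateWrapper(
        forward_decoder = TransformerWrapper(
            num_tokens = 256,
            max_seq_len = 256,
            attn_layers = Decoder(dim = 512, depth = 6, heads = 8, rotary_pos_emb = True)
        ),
        pred_terminal = True,             # enables the nn.Sequential terminal head defined above
        pred_terminal_loss_weight = 0.5   # illustrative weighting of the terminal BCE loss
    )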
@@ -112,7 +128,7 @@ class BeliefStateWrapper(Module):
  prompts,
  seq_len,
  temperature = 1.25,
- cache_kv = True,
+ cache_kv = False,
  suffix: Tensor | None = None, # the goal conditioning
  filter_logits_fn = min_p,
  filter_kwargs = dict(
@@ -122,6 +138,8 @@
  ):
  max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

+ prompts, batch_ps = pack([prompts], '* d')
+
  batch, orig_seq_len = prompts.shape

  out = prompts
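Packing the prompts with pack([prompts], '* d') here, and unpacking the output before returning in the next hunk, appears to let generate_with_suffix_cond accept an unbatched 1-D prompt as well as a batched one, with the sample returned in matching shape. Note that cache_kv now also defaults to False (previous hunk). A usage sketch, assuming a trained model as in train_belief_state.py; the prompt below is random placeholder data:

    import torch

    prompt = torch.randint(0, 256, (64,))   # a single, unbatched prompt of 64 byte-level tokens

    sample = model.generate_with_suffix_cond(
        prompts = prompt,
        seq_len = 256,
        cache_kv = True   # opt back into kv caching now that the default is False
    )

    print(sample.shape)   # 1-D, matching the unbatched prompt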
@@ -183,14 +201,19 @@

  # concat sample

- out = torch.cat((out, sample), dim=-1)
+ out = torch.cat((out, sample), dim = -1)
+
+ out = out[:, orig_seq_len:]

- return out[:, orig_seq_len:]
+ out, = unpack(out, batch_ps, '* n')
+
+ return out

  def forward(
      self,
      seq,
-     backward = True
+     return_loss_only = False,
+     loss_scale = 1.
  ):
      batch, seq_len, device = *seq.shape, seq.device

@@ -230,6 +253,7 @@
  # f - forward, b - backward, i - indices

  fi, bi = fb_pairs.unbind(dim = -1)
+
  valid_mask = (bi - fi) >= 2

  fb_pairs = fb_pairs[valid_mask]
@@ -251,8 +275,9 @@

  labels_fi, labels_bi = (fi + 1), bi

- forward_labels, backward_labels = seq[:, fi], seq[:, bi]
- labels = stack((forward_labels, backward_labels), dim = -1)
+ forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+ labels = cat((forward_labels, backward_labels), dim = -1)

  # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions

@@ -265,33 +290,47 @@

  # cross entropy loss

- fb_loss = F.cross_entropy(
+ loss = F.cross_entropy(
      rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-     rearrange(labels, 'b n fb -> b (fb n)'),
+     labels,
      reduction = 'none' if self.needs_loss_weight else 'mean'
  )

+ # maybe predict terminal
+
+ if exists(self.to_terminal_logit):
+     terminal_logits = self.to_terminal_logit(fb_embeds)
+
+     terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+     terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
+
+     is_end_loss = F.binary_cross_entropy_with_logits(
+         terminal_logits,
+         terminal_labels
+     )
+
+     loss = (
+         loss +
+         is_end_loss * self.pred_terminal_loss_weight
+     )
+
+ # maybe early return loss
+
+ if return_loss_only:
+     return loss
+
  # maybe loss weighting

  if self.needs_loss_weight:
-     fb_loss = rearrange(fb_loss, 'b (fb n) -> b fb n', fb = 2)
-     fb_loss = einx.multiply('b fb n, fb', fb_loss, self.loss_weights)
-     fb_loss = fb_loss.mean()
+     loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
+     loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
+     loss = loss.mean()

  # backwards

- orig_backward = getattr(fb_loss, 'backward')
-
- def patched_backward_fn(*args, **kwargs):
-     orig_backward(*args, **kwargs)
-     orig_forward_embeds.backward(forward_embeds.grad)
-     orig_backward_embeds.backward(backward_embeds.grad)
-
- # can allow the researcher to call .backward from the outside
+ (loss * loss_scale).backward()

- if backward:
-     patched_backward_fn()
- else:
-     setattr(fb_loss, 'backward', patched_backward_fn)
+ orig_forward_embeds.backward(forward_embeds.grad)
+ orig_backward_embeds.backward(backward_embeds.grad)

- return fb_loss
+ return loss
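With the backward-patching removed, forward always backpropagates: it calls (loss * loss_scale).backward() and then pushes the cached embedding gradients back into the forward and backward decoders. Two consequences for callers, sketched below under the same assumptions as train_belief_state.py (model, optimizer, and data loaders already set up): gradient accumulation is done by passing loss_scale rather than scaling the returned loss, and validation uses return_loss_only = True so no backward is attempted under torch.no_grad():

    import torch

    ACCUM_STEPS = 8

    # training: each call already accumulates scaled gradients internally
    for _ in range(ACCUM_STEPS):
        loss = model(next(train_loader), loss_scale = 1. / ACCUM_STEPS)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    # validation: return the loss before the internal backward runs
    model.eval()
    with torch.no_grad():
        val_loss = model(next(val_loader), return_loss_only = True)
    print(f'validation loss: {val_loss.item()}')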