x-transformers 2.1.16__tar.gz → 2.1.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.16 → x_transformers-2.1.17}/PKG-INFO +1 -1
  2. {x_transformers-2.1.16 → x_transformers-2.1.17}/pyproject.toml +1 -1
  3. {x_transformers-2.1.16 → x_transformers-2.1.17}/tests/test_x_transformers.py +1 -2
  4. x_transformers-2.1.17/train_belief_state.py +133 -0
  5. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/belief_state_wrapper.py +30 -23
  6. {x_transformers-2.1.16 → x_transformers-2.1.17}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.16 → x_transformers-2.1.17}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.16 → x_transformers-2.1.17}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.16 → x_transformers-2.1.17}/.gitignore +0 -0
  10. {x_transformers-2.1.16 → x_transformers-2.1.17}/LICENSE +0 -0
  11. {x_transformers-2.1.16 → x_transformers-2.1.17}/README.md +0 -0
  12. {x_transformers-2.1.16 → x_transformers-2.1.17}/data/README.md +0 -0
  13. {x_transformers-2.1.16 → x_transformers-2.1.17}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/fcm.png +0 -0
  24. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/gating.png +0 -0
  28. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/normformer.png +0 -0
  33. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/pia.png +0 -0
  34. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/rezero.png +0 -0
  38. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/rotary.png +0 -0
  39. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.16 → x_transformers-2.1.17}/images/xval.png +0 -0
  46. {x_transformers-2.1.16 → x_transformers-2.1.17}/train_copy.py +0 -0
  47. {x_transformers-2.1.16 → x_transformers-2.1.17}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.16 → x_transformers-2.1.17}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.16 → x_transformers-2.1.17}/train_parity.py +0 -0
  50. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.16 → x_transformers-2.1.17}/x_transformers/xval.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.16
+Version: 2.1.17
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.16"
+version = "2.1.17"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_x_transformers.py
@@ -732,8 +732,7 @@ def test_belief_state_wrapper(
 
     seq = torch.randint(0, 20000, (2, 16))
 
-    loss = model(seq, backward = False)
-    loss.backward()
+    loss = model(seq) # backwards happen automatically
 
    suffix = None
    if goal_suffix:
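
The test change above reflects the new calling convention in 2.1.17: BeliefStateWrapper.forward() no longer takes a backward flag and instead calls backward() on the loss itself (see the x_transformers/belief_state_wrapper.py hunks further down), so the caller only steps the optimizer. A minimal sketch of that convention, not taken from the package; the model sizes, learning rate, and toy sequence are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

# two small decoders sharing the same embedding dimension (sizes are illustrative)
forward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 16,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

backward_model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 16,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model
)

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

seq = torch.randint(0, 20000, (2, 16))

loss = model(seq)      # backward() is now called inside forward()
optimizer.step()       # the caller only steps and zeroes gradients
optimizer.zero_grad()

The enwik8 training script added in this release (next diff) follows the same pattern, using loss_scale to divide the loss across gradient accumulation steps.
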
train_belief_state.py (new file)
@@ -0,0 +1,133 @@
+from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 2
+GRADIENT_ACCUMULATE_EVERY = 8
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 256
+SEQ_LEN = 256
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model for forward and backwards
+
+forward_model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+backward_model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 4, # do a smaller backwards
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+model = BeliefStateWrapper(
+    forward_decoder = forward_model,
+    backward_decoder = backward_model
+)
+
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(
+            next(train_loader),
+            loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
+        )
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss_only = True)
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate_with_suffix_cond(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
+        output_str = decode_tokens(sample)
+        print(output_str)
x_transformers/belief_state_wrapper.py
@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )
 
 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
@@ -69,6 +69,8 @@ class BeliefStateWrapper(Module):
         dim = forward_decoder.emb_dim
         num_tokens = forward_decoder.num_tokens
 
+        self.to_forward_logits = nn.Linear(dim, num_tokens, bias = False)
+
         # the suffix token
 
         self.suffix_token = nn.Parameter(torch.zeros(dim))
@@ -126,7 +128,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv = True,
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +138,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
 
+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape
 
         out = prompts
@@ -197,14 +201,19 @@ class BeliefStateWrapper(Module):
 
         # concat sample
 
-        out = torch.cat((out, sample), dim=-1)
+        out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]
 
-        return out[:, orig_seq_len:]
+        out, = unpack(out, batch_ps, '* n')
+
+        return out
 
     def forward(
         self,
         seq,
-        backward = True
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
@@ -244,6 +253,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices
 
         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2
 
         fb_pairs = fb_pairs[valid_mask]
@@ -265,8 +275,9 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), bi
 
-        forward_labels, backward_labels = seq[:, fi], seq[:, bi]
-        labels = stack((forward_labels, backward_labels), dim = -1)
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)
 
         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
 
@@ -281,7 +292,7 @@ class BeliefStateWrapper(Module):
 
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)'),
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
@@ -290,12 +301,12 @@ class BeliefStateWrapper(Module):
 
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)
 
-            labels = ((bi - fi) == 2).float() # distance is exactly 2
-            labels = repeat(labels, 'n -> b n', b = batch)
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
 
             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-                labels
+                terminal_labels
            )
 
            loss = (
@@ -303,6 +314,11 @@ class BeliefStateWrapper(Module):
                 is_end_loss * self.pred_terminal_loss_weight
             )
 
+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
@@ -312,18 +328,9 @@ class BeliefStateWrapper(Module):
 
         # backwards
 
-        orig_backward = getattr(loss, 'backward')
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()
 
-        if backward:
-            patched_backward_fn()
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)
 
         return loss
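
Taken together, the hunks above make three caller-visible changes: forward() replaces the backward flag with return_loss_only and loss_scale (the backward pass now always runs inside forward() unless return_loss_only is set), generate_with_suffix_cond() packs and unpacks the prompt so a 1-D unbatched prompt round-trips cleanly, and cache_kv now defaults to False during generation. A minimal sketch of those entry points from the caller's side, not taken from the package; model sizes, sequence lengths, and the loss_scale value are illustrative:

import torch
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper

forward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

backward_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 256,
    attn_layers = Decoder(dim = 512, depth = 4, heads = 8)
)

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model
)

seq = torch.randint(0, 256, (2, 256))

# training step: backward runs inside forward; loss_scale divides the loss, e.g. for gradient accumulation
loss = model(seq, loss_scale = 1. / 8)

# validation: return the loss before any backward pass is taken
with torch.no_grad():
    val_loss = model(seq, return_loss_only = True)

# generation: a 1-D (unbatched) prompt is packed internally and the sample comes back unbatched
prompt = torch.randint(0, 256, (64,))
sample = model.generate_with_suffix_cond(
    prompts = prompt,
    seq_len = 128,
    cache_kv = True   # now off by default; enable explicitly for kv caching
)
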