x-transformers 2.1.16__tar.gz → 2.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {x_transformers-2.1.16 → x_transformers-2.1.18}/PKG-INFO +1 -1
  2. {x_transformers-2.1.16 → x_transformers-2.1.18}/pyproject.toml +1 -1
  3. {x_transformers-2.1.16 → x_transformers-2.1.18}/tests/test_x_transformers.py +1 -2
  4. x_transformers-2.1.18/train_belief_state.py +133 -0
  5. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/belief_state_wrapper.py +28 -23
  6. {x_transformers-2.1.16 → x_transformers-2.1.18}/.github/FUNDING.yml +0 -0
  7. {x_transformers-2.1.16 → x_transformers-2.1.18}/.github/workflows/python-publish.yml +0 -0
  8. {x_transformers-2.1.16 → x_transformers-2.1.18}/.github/workflows/python-test.yaml +0 -0
  9. {x_transformers-2.1.16 → x_transformers-2.1.18}/.gitignore +0 -0
  10. {x_transformers-2.1.16 → x_transformers-2.1.18}/LICENSE +0 -0
  11. {x_transformers-2.1.16 → x_transformers-2.1.18}/README.md +0 -0
  12. {x_transformers-2.1.16 → x_transformers-2.1.18}/data/README.md +0 -0
  13. {x_transformers-2.1.16 → x_transformers-2.1.18}/data/enwik8.gz +0 -0
  14. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/all-attention.png +0 -0
  15. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/attention-on-attention.png +0 -0
  16. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/cosine-sim-attention.png +0 -0
  17. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/deepnorm.png +0 -0
  18. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/dynamic-pos-bias-linear.png +0 -0
  19. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/dynamic-pos-bias-log.png +0 -0
  20. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  21. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/dynamic-pos-bias.png +0 -0
  22. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/enhanced-recurrence.png +0 -0
  23. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/fcm.png +0 -0
  24. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/ffglu.png +0 -0
  25. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/flash-attention.png +0 -0
  26. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/gate_values.png +0 -0
  27. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/gating.png +0 -0
  28. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/length-extrapolation-scale.png +0 -0
  29. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/macaron-1.png +0 -0
  30. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/macaron-2.png +0 -0
  31. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/memory-transformer.png +0 -0
  32. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/normformer.png +0 -0
  33. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/pia.png +0 -0
  34. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/qknorm-analysis.png +0 -0
  35. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/resi_dual.png +0 -0
  36. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/residual_attn.png +0 -0
  37. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/rezero.png +0 -0
  38. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/rotary.png +0 -0
  39. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/sandwich-2.png +0 -0
  40. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/sandwich.png +0 -0
  41. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/sandwich_norm.png +0 -0
  42. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/scalenorm.png +0 -0
  43. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/talking-heads.png +0 -0
  44. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/topk-attention.png +0 -0
  45. {x_transformers-2.1.16 → x_transformers-2.1.18}/images/xval.png +0 -0
  46. {x_transformers-2.1.16 → x_transformers-2.1.18}/train_copy.py +0 -0
  47. {x_transformers-2.1.16 → x_transformers-2.1.18}/train_enwik8.py +0 -0
  48. {x_transformers-2.1.16 → x_transformers-2.1.18}/train_length_extrapolate.py +0 -0
  49. {x_transformers-2.1.16 → x_transformers-2.1.18}/train_parity.py +0 -0
  50. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/__init__.py +0 -0
  51. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/attend.py +0 -0
  52. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/autoregressive_wrapper.py +0 -0
  53. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/continuous.py +0 -0
  54. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/dpo.py +0 -0
  55. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/multi_input.py +0 -0
  56. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/neo_mlp.py +0 -0
  57. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/nonautoregressive_wrapper.py +0 -0
  58. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/x_transformers.py +0 -0
  59. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  60. {x_transformers-2.1.16 → x_transformers-2.1.18}/x_transformers/xval.py +0 -0
--- x_transformers-2.1.16/PKG-INFO
+++ x_transformers-2.1.18/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.1.16
+Version: 2.1.18
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
--- x_transformers-2.1.16/pyproject.toml
+++ x_transformers-2.1.18/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.1.16"
+version = "2.1.18"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- x_transformers-2.1.16/tests/test_x_transformers.py
+++ x_transformers-2.1.18/tests/test_x_transformers.py
@@ -732,8 +732,7 @@ def test_belief_state_wrapper(
 
     seq = torch.randint(0, 20000, (2, 16))
 
-    loss = model(seq, backward = False)
-    loss.backward()
+    loss = model(seq) # backwards happen automatically
 
     suffix = None
    if goal_suffix:
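The change above reflects the new contract of BeliefStateWrapper.forward: backpropagation now happens inside the forward call (see the belief_state_wrapper.py hunks further down), so callers no longer invoke loss.backward() themselves. A minimal sketch of the updated call site, assuming model is the wrapped belief-state model from the test and optimizer is any torch optimizer over model.parameters():

    seq = torch.randint(0, 20000, (2, 16))

    loss = model(seq)   # gradients are already populated when this returns
    optimizer.step()
    optimizer.zero_grad()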
--- /dev/null
+++ x_transformers-2.1.18/train_belief_state.py
@@ -0,0 +1,133 @@
+from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 2
+GRADIENT_ACCUMULATE_EVERY = 8
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 256
+SEQ_LEN = 256
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model for forward and backwards
+
+forward_model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+backward_model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 4, # do a smaller backwards
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+model = BeliefStateWrapper(
+    forward_decoder = forward_model,
+    backward_decoder = backward_model
+)
+
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(
+            next(train_loader),
+            loss_scale = 1./ GRADIENT_ACCUMULATE_EVERY
+        )
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss_only = True)
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate_with_suffix_cond(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
+        output_str = decode_tokens(sample)
+        print(output_str)
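The script samples unconditionally from the forward decoder. generate_with_suffix_cond also takes a suffix tensor for goal conditioning (see the belief_state_wrapper.py hunks below), so a goal-conditioned variant of the generation step could look roughly like the sketch here; taking the goal from the tail of a validation sample, and passing it unbatched like the prompt, are assumptions made for illustration rather than part of the released script:

    goal = random.choice(val_dataset)[-8:]   # hypothetical goal: last 8 tokens of a validation sample

    sample = model.generate_with_suffix_cond(
        prompts = inp,
        seq_len = GENERATE_LENGTH,
        suffix = goal,                       # the goal conditioning
        cache_kv = True
    )

    print(decode_tokens(sample))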
--- x_transformers-2.1.16/x_transformers/belief_state_wrapper.py
+++ x_transformers-2.1.18/x_transformers/belief_state_wrapper.py
@@ -23,7 +23,7 @@ from x_transformers.x_transformers import (
 )
 
 import einx
-from einops import rearrange, repeat
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
 # helper functions
@@ -126,7 +126,7 @@ class BeliefStateWrapper(Module):
         prompts,
         seq_len,
         temperature = 1.25,
-        cache_kv = True,
+        cache_kv = False,
         suffix: Tensor | None = None, # the goal conditioning
         filter_logits_fn = min_p,
         filter_kwargs = dict(
@@ -136,6 +136,8 @@ class BeliefStateWrapper(Module):
     ):
         max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
 
+        prompts, batch_ps = pack([prompts], '* d')
+
         batch, orig_seq_len = prompts.shape
 
         out = prompts
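The pack/unpack pair is what lets generate_with_suffix_cond accept a prompt with or without a leading batch dimension: pack([prompts], '* d') collapses any leading axes (possibly none) into a single batch axis and records them in batch_ps, and the matching unpack at the end of generation restores the original shape. A standalone sketch of the einops idiom, separate from the wrapper itself:

    import torch
    from einops import pack, unpack

    prompt = torch.randint(0, 256, (100,))            # unbatched prompt of 100 token ids

    packed, ps = pack([prompt], '* n')                # shape (1, 100); ps remembers the missing batch dims
    longer = torch.cat((packed, packed), dim = -1)    # stand-in for the sampling loop growing the sequence
    restored, = unpack(longer, ps, '* n')             # back to an unbatched tensor, now of shape (200,)

    print(restored.shape)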
@@ -197,14 +199,19 @@ class BeliefStateWrapper(Module):
 
             # concat sample
 
-            out = torch.cat((out, sample), dim=-1)
+            out = torch.cat((out, sample), dim = -1)
+
+        out = out[:, orig_seq_len:]
 
-        return out[:, orig_seq_len:]
+        out, = unpack(out, batch_ps, '* n')
+
+        return out
 
     def forward(
         self,
         seq,
-        backward = True
+        return_loss_only = False,
+        loss_scale = 1.
     ):
         batch, seq_len, device = *seq.shape, seq.device
 
@@ -244,6 +251,7 @@ class BeliefStateWrapper(Module):
         # f - forward, b - backward, i - indices
 
         fi, bi = fb_pairs.unbind(dim = -1)
+
         valid_mask = (bi - fi) >= 2
 
         fb_pairs = fb_pairs[valid_mask]
@@ -265,8 +273,9 @@ class BeliefStateWrapper(Module):
 
         labels_fi, labels_bi = (fi + 1), bi
 
-        forward_labels, backward_labels = seq[:, fi], seq[:, bi]
-        labels = stack((forward_labels, backward_labels), dim = -1)
+        forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
+
+        labels = cat((forward_labels, backward_labels), dim = -1)
 
         # get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
 
@@ -281,7 +290,7 @@ class BeliefStateWrapper(Module):
 
         loss = F.cross_entropy(
             rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
-            rearrange(labels, 'b n fb -> b (fb n)'),
+            labels,
             reduction = 'none' if self.needs_loss_weight else 'mean'
         )
 
@@ -290,12 +299,12 @@
         if exists(self.to_terminal_logit):
             terminal_logits = self.to_terminal_logit(fb_embeds)
 
-            labels = ((bi - fi) == 2).float() # distance is exactly 2
-            labels = repeat(labels, 'n -> b n', b = batch)
+            terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
+            terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
 
             is_end_loss = F.binary_cross_entropy_with_logits(
                 terminal_logits,
-                labels
+                terminal_labels
             )
 
             loss = (
@@ -303,6 +312,11 @@
                 is_end_loss * self.pred_terminal_loss_weight
             )
 
+        # maybe early return loss
+
+        if return_loss_only:
+            return loss
+
         # maybe loss weighting
 
         if self.needs_loss_weight:
@@ -312,18 +326,9 @@
 
         # backwards
 
-        orig_backward = getattr(loss, 'backward')
-
-        def patched_backward_fn(*args, **kwargs):
-            orig_backward(*args, **kwargs)
-            orig_forward_embeds.backward(forward_embeds.grad)
-            orig_backward_embeds.backward(backward_embeds.grad)
-
-        # can allow the researcher to call .backward from the outside
+        (loss * loss_scale).backward()
 
-        if backward:
-            patched_backward_fn()
-        else:
-            setattr(loss, 'backward', patched_backward_fn)
+        orig_forward_embeds.backward(forward_embeds.grad)
+        orig_backward_embeds.backward(backward_embeds.grad)
 
         return loss
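With this rewrite the wrapper always drives its own backward pass: the loss, scaled by loss_scale, is backpropagated inside forward, and the gradients accumulated on the detached forward/backward embeddings are then pushed back through the two decoder trunks. Gradient accumulation is therefore expressed through loss_scale rather than by dividing the returned loss, as in the train_belief_state.py script above; a condensed sketch of the training step, with optimizer and train_loader assumed to be set up as in that script:

    accum_steps = 8

    for _ in range(accum_steps):
        # forward() runs .backward() internally, scaled by loss_scale
        loss = model(next(train_loader), loss_scale = 1. / accum_steps)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()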