x-transformers 2.8.4.tar.gz → 2.9.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. {x_transformers-2.8.4 → x_transformers-2.9.0}/PKG-INFO +2 -1
  2. {x_transformers-2.8.4 → x_transformers-2.9.0}/pyproject.toml +2 -1
  3. x_transformers-2.9.0/train_with_muon.py +132 -0
  4. {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/FUNDING.yml +0 -0
  5. {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/workflows/python-publish.yml +0 -0
  6. {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/workflows/python-test.yaml +0 -0
  7. {x_transformers-2.8.4 → x_transformers-2.9.0}/.gitignore +0 -0
  8. {x_transformers-2.8.4 → x_transformers-2.9.0}/LICENSE +0 -0
  9. {x_transformers-2.8.4 → x_transformers-2.9.0}/README.md +0 -0
  10. {x_transformers-2.8.4 → x_transformers-2.9.0}/data/README.md +0 -0
  11. {x_transformers-2.8.4 → x_transformers-2.9.0}/data/enwik8.gz +0 -0
  12. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/all-attention.png +0 -0
  13. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/attention-on-attention.png +0 -0
  14. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/cosine-sim-attention.png +0 -0
  15. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/deepnorm.png +0 -0
  16. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-linear.png +0 -0
  17. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-log.png +0 -0
  18. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
  19. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias.png +0 -0
  20. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/enhanced-recurrence.png +0 -0
  21. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/fcm.png +0 -0
  22. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/ffglu.png +0 -0
  23. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/flash-attention.png +0 -0
  24. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/gate_values.png +0 -0
  25. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/gating.png +0 -0
  26. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/length-extrapolation-scale.png +0 -0
  27. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/macaron-1.png +0 -0
  28. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/macaron-2.png +0 -0
  29. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/memory-transformer.png +0 -0
  30. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/normformer.png +0 -0
  31. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/pia.png +0 -0
  32. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/qknorm-analysis.png +0 -0
  33. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/resi_dual.png +0 -0
  34. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/residual_attn.png +0 -0
  35. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/rezero.png +0 -0
  36. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/rotary.png +0 -0
  37. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich-2.png +0 -0
  38. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich.png +0 -0
  39. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich_norm.png +0 -0
  40. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/scalenorm.png +0 -0
  41. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/talking-heads.png +0 -0
  42. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/topk-attention.png +0 -0
  43. {x_transformers-2.8.4 → x_transformers-2.9.0}/images/xval.png +0 -0
  44. {x_transformers-2.8.4 → x_transformers-2.9.0}/tests/test_x_transformers.py +0 -0
  45. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_belief_state.py +0 -0
  46. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_copy.py +0 -0
  47. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_entropy_tokenizer.py +0 -0
  48. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_enwik8.py +0 -0
  49. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_gpt_vae.py +0 -0
  50. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_length_extrapolate.py +0 -0
  51. {x_transformers-2.8.4 → x_transformers-2.9.0}/train_parity.py +0 -0
  52. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/__init__.py +0 -0
  53. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/attend.py +0 -0
  54. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/autoregressive_wrapper.py +0 -0
  55. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/belief_state_wrapper.py +0 -0
  56. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/continuous.py +0 -0
  57. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/dpo.py +0 -0
  58. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/entropy_based_tokenizer.py +0 -0
  59. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/gpt_vae.py +0 -0
  60. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/multi_input.py +0 -0
  61. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/neo_mlp.py +0 -0
  62. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
  63. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/up_wrapper.py +0 -0
  64. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/x_transformers.py +0 -0
  65. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
  66. {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/xval.py +0 -0
{x_transformers-2.8.4 → x_transformers-2.9.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.4
+Version: 2.9.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.2.2; extra == 'examples'
 Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
{x_transformers-2.8.4 → x_transformers-2.9.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.8.4"
+version = "2.9.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -35,6 +35,7 @@ Repository = "https://github.com/lucidrains/x-transformers"

 [project.optional-dependencies]
 examples = [
+    "adam-atan2-pytorch>=0.2.2",
     "lion-pytorch",
     "tqdm",
 ]
x_transformers-2.9.0/train_with_muon.py (new file)
@@ -0,0 +1,132 @@
+# /// script
+# dependencies = [
+#   "x-transformers",
+#   "adam-atan2-pytorch>=0.2.2",
+# ]
+# ///
+
+from x_transformers import TransformerWrapper, Decoder
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from adam_atan2_pytorch import MuonAdamAtan2
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+MUON_LEARNING_RATE = 1e-3
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 1024
+SEQ_LEN = 1024
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+ar_wrapper = AutoregressiveWrapper(model)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = MuonAdamAtan2(
+    muon_params = model.muon_parameters(),
+    params = model.parameters(),
+    remove_muon_params_from_params = True,
+    lr = LEARNING_RATE,
+    muon_lr = MUON_LEARNING_RATE,
+)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = ar_wrapper(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = ar_wrapper(next(val_loader))
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = ar_wrapper.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
+        output_str = decode_tokens(sample)
+        print(output_str)
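
The core of the new example is the optimizer wiring: matrix-shaped weights are updated with Muon while the remaining parameters (embeddings, norms, output head) fall back to AdamAtan2. The sketch below isolates that setup from the enwik8 plumbing. It is a minimal sketch, assuming the same adam-atan2-pytorch (>=0.2.2) API and the same TransformerWrapper.muon_parameters() method that the script above relies on; the tiny model dimensions and the random-token training step are purely illustrative.

# minimal sketch of the MuonAdamAtan2 setup from train_with_muon.py
# (assumes the adam-atan2-pytorch API exactly as used in the script above)

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from adam_atan2_pytorch import MuonAdamAtan2

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 128,
    attn_layers = Decoder(dim = 64, depth = 2, heads = 4, rotary_pos_emb = True)
)

wrapper = AutoregressiveWrapper(model)

# matrix-valued weights get the Muon update, everything else is handled by AdamAtan2
optimizer = MuonAdamAtan2(
    muon_params = model.muon_parameters(),
    params = model.parameters(),
    remove_muon_params_from_params = True,  # do not optimize the Muon params twice
    lr = 1e-4,
    muon_lr = 1e-3,
)

# one illustrative step on random token ids
seq = torch.randint(0, 256, (2, 128))
loss = wrapper(seq)
loss.backward()
optimizer.step()
optimizer.zero_grad()

Note also that the # /// script block at the top of the new file is inline script metadata (PEP 723), so a script runner that understands it (for example, uv run train_with_muon.py) can resolve x-transformers and adam-atan2-pytorch before executing the script.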