x-transformers 2.8.4__tar.gz → 2.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.8.4 → x_transformers-2.9.0}/PKG-INFO +2 -1
- {x_transformers-2.8.4 → x_transformers-2.9.0}/pyproject.toml +2 -1
- x_transformers-2.9.0/train_with_muon.py +132 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/.gitignore +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/LICENSE +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/README.md +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/data/README.md +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/data/enwik8.gz +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/all-attention.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/deepnorm.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/fcm.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/ffglu.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/flash-attention.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/gate_values.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/gating.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/macaron-1.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/macaron-2.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/normformer.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/pia.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/resi_dual.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/residual_attn.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/rezero.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/rotary.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/scalenorm.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/talking-heads.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/topk-attention.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/images/xval.png +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_belief_state.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_copy.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_enwik8.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_gpt_vae.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/train_parity.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.4 → x_transformers-2.9.0}/x_transformers/xval.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: x-transformers
|
3
|
-
Version: 2.8.4
|
3
|
+
Version: 2.9.0
|
4
4
|
Summary: X-Transformers
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/x-transformers/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/x-transformers
|
@@ -40,6 +40,7 @@ Requires-Dist: loguru
|
|
40
40
|
Requires-Dist: packaging>=21.0
|
41
41
|
Requires-Dist: torch>=2.0
|
42
42
|
Provides-Extra: examples
|
43
|
+
Requires-Dist: adam-atan2-pytorch>=0.2.2; extra == 'examples'
|
43
44
|
Requires-Dist: lion-pytorch; extra == 'examples'
|
44
45
|
Requires-Dist: tqdm; extra == 'examples'
|
45
46
|
Provides-Extra: test
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "x-transformers"
|
3
|
-
version = "2.8.4"
|
3
|
+
version = "2.9.0"
|
4
4
|
description = "X-Transformers"
|
5
5
|
authors = [
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
@@ -35,6 +35,7 @@ Repository = "https://github.com/lucidrains/x-transformers"
|
|
35
35
|
|
36
36
|
[project.optional-dependencies]
|
37
37
|
examples = [
|
38
|
+
"adam-atan2-pytorch>=0.2.2",
|
38
39
|
"lion-pytorch",
|
39
40
|
"tqdm",
|
40
41
|
]
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# /// script
|
2
|
+
# dependencies = [
|
3
|
+
# "x-transformers",
|
4
|
+
# "adam-atan2-pytorch>=0.2.2",
|
5
|
+
# ]
|
6
|
+
# ///
|
7
|
+
|
8
|
+
from x_transformers import TransformerWrapper, Decoder
|
9
|
+
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
|
10
|
+
|
11
|
+
import random
|
12
|
+
import tqdm
|
13
|
+
import gzip
|
14
|
+
import numpy as np
|
15
|
+
import torch
|
16
|
+
import torch.optim as optim
|
17
|
+
from torch.nn import functional as F
|
18
|
+
from torch.utils.data import DataLoader, Dataset
|
19
|
+
|
20
|
+
from adam_atan2_pytorch import MuonAdamAtan2
|
21
|
+
|
22
|
+
# constants
|
23
|
+
|
24
|
+
# optimisation schedule
NUM_BATCHES = 100_000                 # optimizer steps
BATCH_SIZE = 4                        # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 4         # micro-batches accumulated per optimizer step

# learning rates: LEARNING_RATE for the ordinary parameter group,
# MUON_LEARNING_RATE for the group handed to the optimizer as muon_params
LEARNING_RATE = 1e-4
MUON_LEARNING_RATE = 1e-3

# evaluation / sampling cadence (in optimizer steps) and lengths (in tokens)
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 1024
SEQ_LEN = 1024
|
33
|
+
|
34
|
+
# helpers
|
35
|
+
|
36
|
+
def cycle(loader):
    """Yield the items of ``loader`` endlessly, re-iterating it each time it runs out.

    Note: if ``loader`` yields nothing, ``next()`` on this generator never returns.
    """
    while True:
        yield from loader
|
40
|
+
|
41
|
+
def decode_token(token):
    """Convert one byte-valued token to a single character.

    Token values below 32 are rendered as a space (chr(32)); everything
    else is passed straight to ``chr``.
    """
    code_point = token if token > 32 else 32
    return chr(code_point)

def decode_tokens(tokens):
    """Decode an iterable of tokens into a string, one character per token."""
    return ''.join(decode_token(tok) for tok in tokens)
|
46
|
+
|
47
|
+
# instantiate GPT-like decoder model
|
48
|
+
|
49
|
+
# causal attention stack: 6 layers, width 512, 8 heads, rotary positions
decoder = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    rotary_pos_emb = True
)

# 256 tokens - one per possible byte value of the enwik8 data
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = SEQ_LEN,
    attn_layers = decoder
)

# wrapper provides the training loss (ar_wrapper(seq)) and .generate() used below
ar_wrapper = AutoregressiveWrapper(model)
model.cuda()
|
62
|
+
|
63
|
+
# prepare enwik8 data
|
64
|
+
|
65
|
+
with gzip.open('./data/enwik8.gz') as file:
    # read at most 95M bytes; .copy() gives the array its own writable buffer
    # instead of a read-only view over the bytes object
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

# first 90M bytes for training, the remainder for validation
train_x, valid_x = np.split(data, [int(90e6)])
data_train = torch.from_numpy(train_x)
data_val = torch.from_numpy(valid_x)
|
69
|
+
|
70
|
+
class TextSamplerDataset(Dataset):
    """Dataset of random contiguous token windows of length ``seq_len + 1``.

    The ``index`` passed to ``__getitem__`` is ignored: every access draws a
    fresh random start position, so ``__len__`` only determines how many
    samples make up one pass of a DataLoader over this dataset.

    Args:
        data: 1-D tensor of token ids (this script passes uint8 byte tensors).
        seq_len: number of tokens per window, excluding the one extra token.
        device: device each sampled window is moved to (default ``'cuda'``).
    """
    def __init__(self, data, seq_len, device = 'cuda'):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        # a window spans seq_len + 1 tokens (the extra one is presumably the
        # shifted target consumed by the autoregressive wrapper), so valid starts
        # are 0 .. len - seq_len - 1. torch.randint's upper bound is exclusive,
        # hence high = len - seq_len (the old `- 1` never sampled the last window)
        start = int(torch.randint(0, self.data.size(0) - self.seq_len, (1,)))
        full_seq = self.data[start: start + self.seq_len + 1].long()
        return full_seq.to(self.device)

    def __len__(self):
        return self.data.size(0) // self.seq_len
|
83
|
+
|
84
|
+
# identical loader settings for both splits; incomplete final batches are dropped
loader_kwargs = dict(batch_size = BATCH_SIZE, drop_last = True)

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, **loader_kwargs))

val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
val_loader = cycle(DataLoader(val_dataset, **loader_kwargs))
|
88
|
+
|
89
|
+
# optimizer
|
90
|
+
|
91
|
+
# named `optimizer` (not `optim`) so it does not shadow the `torch.optim` module alias
optimizer = MuonAdamAtan2(
    muon_params = model.muon_parameters(),
    params = model.parameters(),
    remove_muon_params_from_params = True,
    lr = LEARNING_RATE,
    muon_lr = MUON_LEARNING_RATE,
)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several micro-batches, scaling each loss so the
    # summed gradient matches the mean over the accumulated batches
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = ar_wrapper(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # reports the last micro-batch's (unscaled) loss, not the accumulated average
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = ar_wrapper(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # random.choice indexes the dataset, but TextSamplerDataset ignores the
        # index and samples its own window; [:-1] drops the extra final token
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # the previous print(f'%s \n\n %s', (...)) printed the literal
        # placeholders followed by a tuple repr instead of the prime text
        separator = '*' * 100
        print(f'{prime} \n\n {separator}')

        sample = ar_wrapper.generate(
            prompts = inp,
            seq_len = GENERATE_LENGTH,
            cache_kv = True
        )

        output_str = decode_tokens(sample)
        print(output_str)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|