x-transformers 2.8.3__tar.gz → 2.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.8.3 → x_transformers-2.9.0}/PKG-INFO +2 -1
- {x_transformers-2.8.3 → x_transformers-2.9.0}/pyproject.toml +2 -1
- {x_transformers-2.8.3 → x_transformers-2.9.0}/tests/test_x_transformers.py +4 -1
- x_transformers-2.9.0/train_with_muon.py +132 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/x_transformers.py +14 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/.github/FUNDING.yml +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/.gitignore +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/LICENSE +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/README.md +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/data/README.md +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/data/enwik8.gz +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/all-attention.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/attention-on-attention.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/deepnorm.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/fcm.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/ffglu.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/flash-attention.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/gate_values.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/gating.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/macaron-1.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/macaron-2.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/memory-transformer.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/normformer.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/pia.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/resi_dual.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/residual_attn.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/rezero.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/rotary.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/sandwich-2.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/sandwich.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/sandwich_norm.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/scalenorm.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/talking-heads.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/topk-attention.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/images/xval.png +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_belief_state.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_copy.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_entropy_tokenizer.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_enwik8.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_gpt_vae.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_length_extrapolate.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/train_parity.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/__init__.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/attend.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/continuous.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/dpo.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/entropy_based_tokenizer.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/gpt_vae.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/up_wrapper.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/xval.py +0 -0
{x_transformers-2.8.3 → x_transformers-2.9.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.3
+Version: 2.9.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.2.2; extra == 'examples'
 Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
{x_transformers-2.8.3 → x_transformers-2.9.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.8.3"
+version = "2.9.0"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -35,6 +35,7 @@ Repository = "https://github.com/lucidrains/x-transformers"

 [project.optional-dependencies]
 examples = [
+    "adam-atan2-pytorch>=0.2.2",
     "lion-pytorch",
     "tqdm",
 ]
{x_transformers-2.8.3 → x_transformers-2.9.0}/tests/test_x_transformers.py

@@ -1362,7 +1362,7 @@ def test_vae():
     out = model.generate(seq[:, :512], 512, seq_for_latents = style)

 def test_muon_params():
-    from x_transformers import Attention, FeedForward
+    from x_transformers import Attention, FeedForward, Encoder

     attn = Attention(dim = 512, dim_out = 384)
     assert len(list(attn.muon_parameters())) == 2
@@ -1370,3 +1370,6 @@ def test_muon_params():
     ff = FeedForward(dim = 512)

     assert len(list(ff.muon_parameters())) == 2
+
+    enc = Encoder(dim = 512, depth = 2)
+    assert len(enc.muon_parameters()) > 0
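The extended test covers the new aggregation path: muon_parameters() is now available on a whole layer stack, not just on individual Attention and FeedForward blocks. Below is a minimal sketch of the invariant the new AttentionLayers.muon_parameters loop implies, assuming Encoder inherits the method from AttentionLayers (which the test above relies on): the stack-level list is the concatenation of the per-block lists.

from x_transformers import Attention, FeedForward, Encoder

enc = Encoder(dim = 512, depth = 2)

# aggregate the per-block lists the same way AttentionLayers.muon_parameters does
per_block_total = sum(
    len(list(m.muon_parameters()))
    for m in enc.modules()
    if isinstance(m, (Attention, FeedForward))
)

assert len(enc.muon_parameters()) == per_block_total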
x_transformers-2.9.0/train_with_muon.py (new file)

@@ -0,0 +1,132 @@
+# /// script
+# dependencies = [
+#   "x-transformers",
+#   "adam-atan2-pytorch>=0.2.2",
+# ]
+# ///
+
+from x_transformers import TransformerWrapper, Decoder
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from adam_atan2_pytorch import MuonAdamAtan2
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+MUON_LEARNING_RATE = 1e-3
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 1024
+SEQ_LEN = 1024
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+ar_wrapper = AutoregressiveWrapper(model)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = MuonAdamAtan2(
+    muon_params = model.muon_parameters(),
+    params = model.parameters(),
+    remove_muon_params_from_params = True,
+    lr = LEARNING_RATE,
+    muon_lr = MUON_LEARNING_RATE,
+)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = ar_wrapper(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = ar_wrapper(next(val_loader))
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = ar_wrapper.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
+        output_str = decode_tokens(sample)
+        print(output_str)
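The "# /// script" block at the top is PEP 723 inline script metadata, so the example can be launched with a PEP 723-aware runner (for instance "uv run train_with_muon.py"), assuming a CUDA device and the repository's ./data/enwik8.gz are available. The optimizer setup is the release-specific part: MuonAdamAtan2 receives both the full parameter list and the subset returned by the new model.muon_parameters(), with remove_muon_params_from_params = True (which, as the name suggests, drops the muon subset from the general group) and a separate muon_lr for the muon parameters.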
{x_transformers-2.8.3 → x_transformers-2.9.0}/x_transformers/x_transformers.py

@@ -2493,6 +2493,17 @@ class AttentionLayers(Module):
         for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
             attn_layer.qk_clip_(attn_inter, tau = tau)

+    def muon_parameters(self):
+        params = []
+
+        for m in self.modules():
+            if not isinstance(m, (Attention, FeedForward)):
+                continue
+
+            params.extend(list(m.muon_parameters()))
+
+        return params
+
     def forward(
         self,
         x,
@@ -3230,6 +3241,9 @@ class TransformerWrapper(Module):
     ):
         self.attn_layers.attn_qk_clip_(intermediates, tau = tau)

+    def muon_parameters(self):
+        return self.attn_layers.muon_parameters()
+
     def forward(
         self,
         x,
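These 14 added lines are the core library change: AttentionLayers.muon_parameters() walks its submodules and collects the parameter lists reported by each Attention and FeedForward block, and TransformerWrapper.muon_parameters() simply delegates to its attn_layers. A minimal sketch of how that split can be consumed outside of MuonAdamAtan2, using plain torch.optim.AdamW purely as a stand-in two-group optimizer (the point here is the parameter grouping, not the update rule; the learning rates mirror the example script):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# parameters collected from the Attention / FeedForward blocks
muon_params = list(model.muon_parameters())
muon_ids = set(map(id, muon_params))

# every remaining parameter (anything not reported by an Attention / FeedForward block)
other_params = [p for p in model.parameters() if id(p) not in muon_ids]

optimizer = torch.optim.AdamW([
    dict(params = other_params, lr = 1e-4),
    dict(params = muon_params, lr = 1e-3),
])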