x-transformers 2.0.3__tar.gz → 2.0.5__tar.gz
- {x_transformers-2.0.3 → x_transformers-2.0.5}/PKG-INFO +1 -1
- x_transformers-2.0.5/data/README.md +3 -0
- x_transformers-2.0.5/data/enwik8.gz +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/pyproject.toml +1 -1
- x_transformers-2.0.5/train_length_extrapolate.py +137 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/x_transformers.py +5 -6
- {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/FUNDING.yml +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/.gitignore +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/LICENSE +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/README.md +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/all-attention.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/attention-on-attention.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/deepnorm.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/fcm.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/ffglu.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/flash-attention.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/gate_values.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/gating.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/macaron-1.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/macaron-2.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/memory-transformer.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/normformer.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/pia.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/resi_dual.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/residual_attn.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/rezero.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/rotary.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich-2.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/sandwich_norm.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/scalenorm.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/talking-heads.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/topk-attention.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/images/xval.png +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/tests/test_x_transformers.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/train_copy.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/train_enwik8.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/train_parity.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/__init__.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/attend.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/continuous.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/dpo.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.0.3 → x_transformers-2.0.5}/x_transformers/xval.py +0 -0
x_transformers-2.0.5/data/enwik8.gz: binary file (contents not shown)
x_transformers-2.0.5/train_length_extrapolate.py (new file)

@@ -0,0 +1,137 @@
+from x_transformers import TransformerWrapper, Decoder
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 256
+SEQ_LEN = 256
+
+VALIDATE_EVERY = 100
+VALIDATE_SEQ_LENS = (256, 512, 1024, 2048, 4096)
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    use_abs_pos_emb = False,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        dynamic_pos_bias = True,
+    )
+)
+
+model = AutoregressiveWrapper(model)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+val_dataset_generate = TextSamplerDataset(data_val, SEQ_LEN)
+
+# validation loaders with different sequence lengths
+
+val_loaders = dict()
+
+for valid_seq_len in VALIDATE_SEQ_LENS:
+    val_dataset = TextSamplerDataset(data_val, valid_seq_len)
+    val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+    val_loaders[valid_seq_len] = val_loader
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {loss.item()}')
+
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        print(f'validation losses:\n')
+
+        model.eval()
+        with torch.no_grad():
+            for valid_seq_len in VALIDATE_SEQ_LENS:
+                val_loader = val_loaders[valid_seq_len]
+
+                loss = model(next(val_loader))
+                print(f'[{valid_seq_len}]:\t {loss.item()}')
+
+        print('\n')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset_generate)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(
+            prompts = inp,
+            seq_len = GENERATE_LENGTH,
+            cache_kv = True
+        )
+
+        output_str = decode_tokens(sample)
+        print(f'{output_str}\n\n')
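Note (editor's sketch, not part of the release): the point of the new script is that a decoder trained at SEQ_LEN = 256 with use_abs_pos_emb = False and dynamic_pos_bias = True can be evaluated at the longer VALIDATE_SEQ_LENS without retraining. A minimal way to exercise that behaviour with the same API, assuming a CPU-only environment and random byte-level tokens:

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# same configuration as the training script above, minus .cuda()
model = AutoregressiveWrapper(TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 256,           # length used during training
    use_abs_pos_emb = False,     # no absolute positions, so longer inputs are accepted
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        dynamic_pos_bias = True  # relative bias produced by an MLP, extrapolates in length
    )
))

# random tokens at 4x the training length (one extra token for the target shift)
tokens = torch.randint(0, 256, (1, 1024 + 1))
loss = model(tokens)             # AutoregressiveWrapper returns the autoregressive loss
print(loss.item())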
x_transformers/x_transformers.py

@@ -449,17 +449,16 @@ class DynamicPositionBias(Module):
         return next(self.parameters()).device

     def forward(self, i, j):
-        assert i == j
         n, device = j, self.device

         # get the (n x n) matrix of distances
-        seq_arange = arange(n, device = device)
-        context_arange = arange(n, device = device)
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
-        indices += (n - 1)
+        indices += (j - 1)

         # input to continuous positions MLP
-        pos = arange(-n + 1, n, device = device).float()
+        pos = arange(-j + 1, j, device = device).float()
         pos = rearrange(pos, '... -> ... 1')

         if self.log_distance:
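The DynamicPositionBias change drops the old assert i == j: the bias can now be built for a rectangular attention matrix of i queries over j keys, with the queries taken as the last i positions of the length-j context. A standalone sketch of just the index arithmetic, written with plain torch broadcasting in place of einx.subtract (illustration only, not code from the package):

import torch

def relative_position_indices(i, j):
    # query positions are the last i of a context of length j
    seq_arange = torch.arange(j - i, j)
    context_arange = torch.arange(j)
    # signed distances, shape (i, j), mirroring einx.subtract('i, j -> i j', ...)
    indices = seq_arange[:, None] - context_arange[None, :]
    # shift into [0, 2j - 2] so they index the 2j - 1 bias values built from arange(-j + 1, j)
    return indices + (j - 1)

print(relative_position_indices(4, 4))  # square case, matches the previous i == j behaviour
print(relative_position_indices(2, 6))  # 2 queries attending over 6 keys (e.g. cached decoding)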
@@ -1282,7 +1281,7 @@ class Attention(Module):
         dim_kv_input = dim_latent_kv

         if exists(latent_rope_subheads):
-            assert not exists(rotate_num_heads)
+            assert not exists(rotate_num_heads), '`rotate_num_heads` cannot be set when multi-latent attention is being used'
             rotate_num_heads = latent_rope_subheads

             k_dim = dim_head * (kv_heads - latent_rope_subheads)
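The Attention change only adds an explanatory message to an existing assertion: when latent_rope_subheads is supplied (multi-latent attention), rotate_num_heads is set internally and must not also be passed by the caller. A hypothetical sketch of hitting the guard; the keyword names come from the diff, but exactly how they are exposed on the constructor is an assumption on my part:

from x_transformers.x_transformers import Attention

try:
    # hypothetical combination; names come from the diff, values are placeholders
    attn = Attention(
        dim = 512,
        heads = 8,
        dim_latent_kv = 256,       # assumed to enable the multi-latent path
        latent_rope_subheads = 2,  # rotary restricted to dedicated latent subheads
        rotate_num_heads = 4       # conflicts with the above
    )
except AssertionError as err:
    print(err)  # '`rotate_num_heads` cannot be set when multi-latent attention is being used'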