x-transformers 2.2.7.tar.gz → 2.2.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.2.7 → x_transformers-2.2.9}/PKG-INFO +1 -1
- {x_transformers-2.2.7 → x_transformers-2.2.9}/pyproject.toml +1 -1
- {x_transformers-2.2.7 → x_transformers-2.2.9}/tests/test_x_transformers.py +21 -1
- x_transformers-2.2.9/train_entropy_tokenizer.py +118 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/entropy_based_tokenizer.py +28 -3
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/x_transformers.py +4 -1
- {x_transformers-2.2.7 → x_transformers-2.2.9}/.github/FUNDING.yml +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/.gitignore +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/LICENSE +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/README.md +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/data/README.md +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/data/enwik8.gz +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/all-attention.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/attention-on-attention.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/deepnorm.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/fcm.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/ffglu.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/flash-attention.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/gate_values.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/gating.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/macaron-1.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/macaron-2.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/memory-transformer.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/normformer.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/pia.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/resi_dual.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/residual_attn.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/rezero.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/rotary.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/sandwich-2.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/sandwich.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/sandwich_norm.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/scalenorm.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/talking-heads.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/topk-attention.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/images/xval.png +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/train_belief_state.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/train_copy.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/train_enwik8.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/train_length_extrapolate.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/train_parity.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/__init__.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/attend.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/belief_state_wrapper.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/continuous.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/dpo.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.2.7 → x_transformers-2.2.9}/x_transformers/xval.py +0 -0
--- x_transformers-2.2.7/tests/test_x_transformers.py
+++ x_transformers-2.2.9/tests/test_x_transformers.py
@@ -798,4 +798,24 @@ def test_entropy_based_tokenizer(
 
     assert len(segmented_seq) == seq.shape[0]
 
-    tokenizer(seq[0]) # able to handle without batch dim
+    tokenizer(seq[0]) # able to handle without batch dim
+
+def test_custom_ff_activation():
+
+    model = TransformerWrapper(
+        num_tokens = 20000,
+        max_seq_len = 1024,
+        attn_layers = Decoder(
+            dim = 128,
+            depth = 6,
+            heads = 8,
+            attn_dim_head = 64,
+            ff_custom_activation = nn.Sigmoid()
+        )
+    )
+
+    seq = torch.randint(0, 20000, (2, 1024))
+
+    logits = model(seq)
+
+    assert logits.shape == (2, 1024, 20000)
--- /dev/null
+++ x_transformers-2.2.9/train_entropy_tokenizer.py
@@ -0,0 +1,118 @@
+from x_transformers import TransformerWrapper, Decoder
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer
+
+import random
+import tqdm
+import gzip
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 1e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 100
+GENERATE_LENGTH = 1024
+SEQ_LEN = 1024
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate GPT-like decoder model
+
+model = TransformerWrapper(
+    num_tokens = 256,
+    max_seq_len = SEQ_LEN,
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8,
+        rotary_pos_emb = True
+    )
+)
+
+tokenizer = EntropyBasedTokenizer(
+    model,
+    entropy_threshold = 2.5
+)
+
+model = AutoregressiveWrapper(model)
+model.cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
+    train_x, valid_x = np.split(data, [int(90e6)])
+    data_train, data_val = torch.from_numpy(train_x), torch.from_numpy(valid_x)
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))
+
+# optimizer
+
+optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader))
+        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader))
+            print(f'validation loss: {loss.item()}')
+
+    if i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+
+        tokens = tokenizer(inp, return_segmented_seq = True)
+
+        delimiter = " \u275A "
+        output_str = delimiter.join([decode_tokens(token) for token in tokens])
+
+        print(f"{output_str}\n\n")
--- x_transformers-2.2.7/x_transformers/entropy_based_tokenizer.py
+++ x_transformers-2.2.9/x_transformers/entropy_based_tokenizer.py
@@ -33,12 +33,15 @@ class EntropyBasedTokenizer(Module):
     def __init__(
         self,
         decoder: Module,
-        entropy_threshold: float
+        entropy_threshold: float,
+        max_token_size: int | None = None
     ):
         super().__init__()
         self.decoder = decoder
         self.entropy_threshold = entropy_threshold
 
+        self.max_token_size = max_token_size
+
     @torch.no_grad()
     def forward(
         self,
@@ -53,7 +56,7 @@ class EntropyBasedTokenizer(Module):
         self.decoder.eval()
 
         is_var_length = exists(lens)
-        batch, seq_len, device = *seq.shape, seq.device
+        batch, seq_len, device, max_token_size = *seq.shape, seq.device, self.max_token_size
 
         arange = torch.arange(seq_len, device = device)
 
@@ -94,7 +97,29 @@ class EntropyBasedTokenizer(Module):
             scatter_indices = rearrange(lens - 1, 'b -> b 1')
             boundaries.scatter_(-1, scatter_indices, True)
 
-
+        # handle max token size - technique has the flaw that repeating subsequences are grouped into one large token
+
+        if exists(max_token_size):
+            token_ids = boundaries.cumsum(dim = -1)
+            token_ids = F.pad(token_ids, (1, -1), value = 0)
+
+            max_num_tokens = boundaries.sum(dim = -1).amax().item()
+            token_ids_seq = torch.arange(max_num_tokens, device = device)
+
+            token_mask = einx.equal('j, b i -> b j i', token_ids_seq, token_ids)
+
+            token_sub_seq_arange = token_mask.cumsum(dim = -1)
+
+            sub_seq_boundaries = (token_sub_seq_arange % max_token_size == 0)
+            sub_seq_boundaries = (sub_seq_boundaries & token_mask).any(dim = 1)
+
+            boundaries = boundaries | sub_seq_boundaries
+
+        # number of tokens
+
+        num_tokens = boundaries.sum(dim = -1)
+
+        # get number of tokens as well as derived indices
 
         indices = arange_plus_one[boundaries].split(num_tokens.tolist())
 
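For context, a minimal usage sketch of the new max_token_size argument (not part of the diff itself; the model dimensions and the cap value of 16 are illustrative assumptions, while entropy_threshold = 2.5 and return_segmented_seq mirror the bundled training script):

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.entropy_based_tokenizer import EntropyBasedTokenizer

# byte-level decoder used to score per-position entropy; sizes are illustrative
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

tokenizer = EntropyBasedTokenizer(
    model,
    entropy_threshold = 2.5,   # same threshold as train_entropy_tokenizer.py
    max_token_size = 16        # new in 2.2.9 - caps any single segment at 16 positions (assumed value)
)

seq = torch.randint(0, 256, (2, 1024))

# returns the per-sample segments rather than just boundary information
segmented = tokenizer(seq, return_segmented_seq = True)

The cap addresses the flaw noted in the diff comment: long repeating subsequences have uniformly low entropy and would otherwise collapse into a single oversized token.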
--- x_transformers-2.2.7/x_transformers/x_transformers.py
+++ x_transformers-2.2.9/x_transformers/x_transformers.py
@@ -1196,6 +1196,7 @@ class FeedForward(Module):
         glu_mult_bias = False,
         swish = False,
         relu_squared = False,
+        custom_activation = None,
         post_act_ln = False,
         dropout = 0.,
         no_bias = False,
@@ -1205,7 +1206,9 @@ class FeedForward(Module):
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
 
-        if relu_squared:
+        if exists(custom_activation):
+            activation = deepcopy(custom_activation)
+        elif relu_squared:
             activation = ReluSquared()
         elif swish:
             activation = nn.SiLU()
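A short sketch of the custom feedforward activation added above, mirroring the new test_custom_ff_activation test: any nn.Module can be supplied and takes precedence over the built-in relu_squared / swish flags, and it is deep-copied so each FeedForward block gets its own instance. Passing it as ff_custom_activation on the attention layers follows the library's existing ff_ kwarg-prefix routing; the GELU choice below is an illustrative assumption (the test itself uses nn.Sigmoid()).

import torch
from torch import nn
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        ff_custom_activation = nn.GELU(approximate = 'tanh')  # any nn.Module works here
    )
)

logits = model(torch.randint(0, 20000, (2, 1024)))
assert logits.shape == (2, 1024, 20000)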