x-transformers 2.1.15__tar.gz → 2.1.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-2.1.15 → x_transformers-2.1.17}/PKG-INFO +1 -1
- {x_transformers-2.1.15 → x_transformers-2.1.17}/pyproject.toml +1 -1
- {x_transformers-2.1.15 → x_transformers-2.1.17}/tests/test_x_transformers.py +1 -2
- x_transformers-2.1.17/train_belief_state.py +133 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/belief_state_wrapper.py +65 -26
- {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/FUNDING.yml +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/workflows/python-publish.yml +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/.github/workflows/python-test.yaml +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/.gitignore +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/LICENSE +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/README.md +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/data/README.md +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/data/enwik8.gz +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/all-attention.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/attention-on-attention.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/cosine-sim-attention.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/deepnorm.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-linear.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-log.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias-sinusoidal.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/dynamic-pos-bias.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/enhanced-recurrence.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/fcm.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/ffglu.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/flash-attention.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/gate_values.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/gating.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/length-extrapolation-scale.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/macaron-1.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/macaron-2.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/memory-transformer.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/normformer.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/pia.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/qknorm-analysis.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/resi_dual.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/residual_attn.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/rezero.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/rotary.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich-2.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/sandwich_norm.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/scalenorm.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/talking-heads.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/topk-attention.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/images/xval.png +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/train_copy.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/train_enwik8.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/train_length_extrapolate.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/train_parity.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/__init__.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/attend.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/continuous.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/dpo.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/multi_input.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/neo_mlp.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/x_transformers.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-2.1.15 → x_transformers-2.1.17}/x_transformers/xval.py +0 -0
@@ -0,0 +1,133 @@
|
|
1
|
+
from x_transformers import TransformerWrapper, Decoder, BeliefStateWrapper
|
2
|
+
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
|
3
|
+
|
4
|
+
import random
|
5
|
+
import tqdm
|
6
|
+
import gzip
|
7
|
+
import numpy as np
|
8
|
+
import torch
|
9
|
+
import torch.optim as optim
|
10
|
+
from torch.nn import functional as F
|
11
|
+
from torch.utils.data import DataLoader, Dataset
|
12
|
+
|
13
|
+
# constants
#
# Hyperparameters for the enwik8 belief-state training run.

NUM_BATCHES = 100_000             # number of optimizer steps in the training loop
BATCH_SIZE = 2                    # sequences per micro-batch
GRADIENT_ACCUMULATE_EVERY = 8     # micro-batches accumulated before each optimizer step
LEARNING_RATE = 1e-4              # Adam learning rate
VALIDATE_EVERY = 100              # steps between validation-loss prints
GENERATE_EVERY = 500              # steps between sample generations
GENERATE_LENGTH = 256             # passed as `seq_len` to the sampler
SEQ_LEN = 256                     # training window length, also each decoder's max_seq_len
|
23
|
+
|
24
|
+
# helpers
|
25
|
+
|
26
|
+
def cycle(loader):
    """Yield items from ``loader`` endlessly, re-iterating it each time it runs out."""
    while True:
        yield from loader
|
30
|
+
|
31
|
+
def decode_token(token):
    """Turn one token id into a character, clamping ids below 32 up to a space."""
    floor = 32  # ASCII space; smaller ids are control characters
    return chr(token) if token > floor else chr(floor)


def decode_tokens(tokens):
    """Decode an iterable of token ids into a string, one character per token."""
    chars = [decode_token(tok) for tok in tokens]
    return ''.join(chars)
|
36
|
+
|
37
|
+
# instantiate GPT-like decoder models for the forward and backward directions

def make_byte_decoder(depth):
    """Build a 256-token TransformerWrapper around a rotary-embedding Decoder of the given depth.

    Both directions share every setting except depth.
    """
    return TransformerWrapper(
        num_tokens = 256,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(
            dim = 512,
            depth = depth,
            heads = 8,
            rotary_pos_emb = True
        )
    )

forward_model = make_byte_decoder(depth = 6)
backward_model = make_byte_decoder(depth = 4)  # do a smaller backwards

model = BeliefStateWrapper(
    forward_decoder = forward_model,
    backward_decoder = backward_model
)

model.cuda()
|
67
|
+
|
68
|
+
# prepare enwik8 data: first 95M bytes, split 90M train / 5M validation

with gzip.open('./data/enwik8.gz') as file:
    # copy so the array owns writable memory instead of viewing the read-only bytes
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()

train_x, valid_x = np.split(data, [int(90e6)])
data_train, data_val = map(torch.from_numpy, (train_x, valid_x))
|
74
|
+
|
75
|
+
class TextSamplerDataset(Dataset):
    """Random fixed-length windows sliced along dim 0 of ``data``.

    Each item is ``seq_len + 1`` consecutive tokens from a random offset (the
    ``index`` argument is ignored), cast to int64 and moved to the GPU.
    ``__len__`` reports how many non-overlapping ``seq_len`` windows fit.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        high = self.data.size(0) - window  # exclusive upper bound for the random start
        start = torch.randint(0, high, (1,))
        return self.data[start: start + window].long().cuda()
|
88
|
+
|
89
|
+
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# endless batch iterators; drop_last keeps every batch at exactly BATCH_SIZE
loader_kwargs = dict(batch_size = BATCH_SIZE, drop_last = True)
train_loader = cycle(DataLoader(train_dataset, **loader_kwargs))
val_loader = cycle(DataLoader(val_dataset, **loader_kwargs))
|
93
|
+
|
94
|
+
# optimizer
# (named `optimizer` so it does not shadow the imported `torch.optim as optim` module)

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # BeliefStateWrapper.forward calls backward itself on (loss * loss_scale) and
    # returns the loss, so gradients accumulate across these micro-batches
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(
            next(train_loader),
            loss_scale = 1. / GRADIENT_ACCUMULATE_EVERY
        )

    # reports only the last micro-batch's (unscaled) loss
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optimizer.step()
    optimizer.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            # return_loss_only skips the internal backward pass
            loss = model(next(val_loader), return_loss_only = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)

        # previously print(f'%s \n\n %s', (prime, ...)) emitted the literal "%s" text
        # plus a tuple repr; interpolate the prompt and separator properly
        print(f'{prime} \n\n {"*" * 100}')

        # sampling needs no autograd graph
        with torch.no_grad():
            sample = model.generate_with_suffix_cond(
                prompts = inp,
                seq_len = GENERATE_LENGTH,
                cache_kv = True
            )

        output_str = decode_tokens(sample)
        print(output_str)
|
@@ -23,7 +23,8 @@ from x_transformers.x_transformers import (
|
|
23
23
|
)
|
24
24
|
|
25
25
|
import einx
|
26
|
-
from einops import rearrange, repeat
|
26
|
+
from einops import rearrange, repeat, pack, unpack
|
27
|
+
from einops.layers.torch import Rearrange
|
27
28
|
|
28
29
|
# helper functions
|
29
30
|
|
@@ -55,7 +56,9 @@ class BeliefStateWrapper(Module):
|
|
55
56
|
backward_decoder: TransformerWrapper | None = None,
|
56
57
|
train_frac_forward_backward_pairs: float = 1.,
|
57
58
|
text_head: Module | None = None,
|
58
|
-
backward_ar_loss_weight: float = 1
|
59
|
+
backward_ar_loss_weight: float = 1., # can weigh the training of the backwards decoder differently, perhaps fwd/bwd have a shared backbone etc etc
|
60
|
+
pred_terminal = False,
|
61
|
+
pred_terminal_loss_weight: float = 1.
|
59
62
|
):
|
60
63
|
super().__init__()
|
61
64
|
backward_decoder = default(backward_decoder, forward_decoder) # if backward decoder not set, use the same transformer, assume it knows how to switch gears based on suffix token
|
@@ -66,6 +69,8 @@ class BeliefStateWrapper(Module):
|
|
66
69
|
dim = forward_decoder.emb_dim
|
67
70
|
num_tokens = forward_decoder.num_tokens
|
68
71
|
|
72
|
+
self.to_forward_logits = nn.Linear(dim, num_tokens, bias = False)
|
73
|
+
|
69
74
|
# the suffix token
|
70
75
|
|
71
76
|
self.suffix_token = nn.Parameter(torch.zeros(dim))
|
@@ -82,6 +87,17 @@ class BeliefStateWrapper(Module):
|
|
82
87
|
|
83
88
|
self.text_head = text_head
|
84
89
|
|
90
|
+
# predicting terminal state (when suffix and prefix predict the same token)
|
91
|
+
|
92
|
+
self.to_terminal_logit = nn.Sequential(
|
93
|
+
nn.Linear(dim * 2, dim),
|
94
|
+
nn.LeakyReLU(),
|
95
|
+
nn.Linear(dim, 1),
|
96
|
+
Rearrange('... 1 -> ...')
|
97
|
+
) if pred_terminal else None
|
98
|
+
|
99
|
+
self.pred_terminal_loss_weight = pred_terminal_loss_weight
|
100
|
+
|
85
101
|
# the two decoders, one which is causal forward, the other causal backwards
|
86
102
|
|
87
103
|
self.forward_decoder = forward_decoder
|
@@ -112,7 +128,7 @@ class BeliefStateWrapper(Module):
|
|
112
128
|
prompts,
|
113
129
|
seq_len,
|
114
130
|
temperature = 1.25,
|
115
|
-
cache_kv =
|
131
|
+
cache_kv = False,
|
116
132
|
suffix: Tensor | None = None, # the goal conditioning
|
117
133
|
filter_logits_fn = min_p,
|
118
134
|
filter_kwargs = dict(
|
@@ -122,6 +138,8 @@ class BeliefStateWrapper(Module):
|
|
122
138
|
):
|
123
139
|
max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
|
124
140
|
|
141
|
+
prompts, batch_ps = pack([prompts], '* d')
|
142
|
+
|
125
143
|
batch, orig_seq_len = prompts.shape
|
126
144
|
|
127
145
|
out = prompts
|
@@ -183,14 +201,19 @@ class BeliefStateWrapper(Module):
|
|
183
201
|
|
184
202
|
# concat sample
|
185
203
|
|
186
|
-
out = torch.cat((out, sample), dim
|
204
|
+
out = torch.cat((out, sample), dim = -1)
|
205
|
+
|
206
|
+
out = out[:, orig_seq_len:]
|
187
207
|
|
188
|
-
|
208
|
+
out, = unpack(out, batch_ps, '* n')
|
209
|
+
|
210
|
+
return out
|
189
211
|
|
190
212
|
def forward(
|
191
213
|
self,
|
192
214
|
seq,
|
193
|
-
|
215
|
+
return_loss_only = False,
|
216
|
+
loss_scale = 1.
|
194
217
|
):
|
195
218
|
batch, seq_len, device = *seq.shape, seq.device
|
196
219
|
|
@@ -230,6 +253,7 @@ class BeliefStateWrapper(Module):
|
|
230
253
|
# f - forward, b - backward, i - indices
|
231
254
|
|
232
255
|
fi, bi = fb_pairs.unbind(dim = -1)
|
256
|
+
|
233
257
|
valid_mask = (bi - fi) >= 2
|
234
258
|
|
235
259
|
fb_pairs = fb_pairs[valid_mask]
|
@@ -251,8 +275,9 @@ class BeliefStateWrapper(Module):
|
|
251
275
|
|
252
276
|
labels_fi, labels_bi = (fi + 1), bi
|
253
277
|
|
254
|
-
forward_labels, backward_labels = seq[:,
|
255
|
-
|
278
|
+
forward_labels, backward_labels = seq[:, labels_fi], seq[:, labels_bi]
|
279
|
+
|
280
|
+
labels = cat((forward_labels, backward_labels), dim = -1)
|
256
281
|
|
257
282
|
# get the forward and backward embedding pairs and feed them through the text head for both forward and backward predictions
|
258
283
|
|
@@ -265,33 +290,47 @@ class BeliefStateWrapper(Module):
|
|
265
290
|
|
266
291
|
# cross entropy loss
|
267
292
|
|
268
|
-
|
293
|
+
loss = F.cross_entropy(
|
269
294
|
rearrange(logits, 'b n (fb l) -> b l (fb n)', fb = 2),
|
270
|
-
|
295
|
+
labels,
|
271
296
|
reduction = 'none' if self.needs_loss_weight else 'mean'
|
272
297
|
)
|
273
298
|
|
299
|
+
# maybe predict terminal
|
300
|
+
|
301
|
+
if exists(self.to_terminal_logit):
|
302
|
+
terminal_logits = self.to_terminal_logit(fb_embeds)
|
303
|
+
|
304
|
+
terminal_labels = ((bi - fi) == 2).float() # distance is exactly 2
|
305
|
+
terminal_labels = repeat(terminal_labels, 'n -> b n', b = batch)
|
306
|
+
|
307
|
+
is_end_loss = F.binary_cross_entropy_with_logits(
|
308
|
+
terminal_logits,
|
309
|
+
terminal_labels
|
310
|
+
)
|
311
|
+
|
312
|
+
loss = (
|
313
|
+
loss +
|
314
|
+
is_end_loss * self.pred_terminal_loss_weight
|
315
|
+
)
|
316
|
+
|
317
|
+
# maybe early return loss
|
318
|
+
|
319
|
+
if return_loss_only:
|
320
|
+
return loss
|
321
|
+
|
274
322
|
# maybe loss weighting
|
275
323
|
|
276
324
|
if self.needs_loss_weight:
|
277
|
-
|
278
|
-
|
279
|
-
|
325
|
+
loss = rearrange(loss, 'b (fb n) -> b fb n', fb = 2)
|
326
|
+
loss = einx.multiply('b fb n, fb', loss, self.loss_weights)
|
327
|
+
loss = loss.mean()
|
280
328
|
|
281
329
|
# backwards
|
282
330
|
|
283
|
-
|
284
|
-
|
285
|
-
def patched_backward_fn(*args, **kwargs):
|
286
|
-
orig_backward(*args, **kwargs)
|
287
|
-
orig_forward_embeds.backward(forward_embeds.grad)
|
288
|
-
orig_backward_embeds.backward(backward_embeds.grad)
|
289
|
-
|
290
|
-
# can allow the researcher to call .backward from the outside
|
331
|
+
(loss * loss_scale).backward()
|
291
332
|
|
292
|
-
|
293
|
-
|
294
|
-
else:
|
295
|
-
setattr(fb_loss, 'backward', patched_backward_fn)
|
333
|
+
orig_forward_embeds.backward(forward_embeds.grad)
|
334
|
+
orig_backward_embeds.backward(backward_embeds.grad)
|
296
335
|
|
297
|
-
return
|
336
|
+
return loss
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|