titans-pytorch 0.0.32__tar.gz → 0.0.35__tar.gz
Potentially problematic release: this version of titans-pytorch has been flagged as potentially problematic.
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/PKG-INFO +1 -1
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/pyproject.toml +1 -1
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/__init__.py +2 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/mac_transformer.py +25 -4
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/train.py +0 -3
- titans_pytorch-0.0.35/train_mac.py +129 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.gitignore +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/LICENSE +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/README.md +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/README.md +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig1.png +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig2.png +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/requirements.txt +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans_attn_memory.py +0 -0
--- titans_pytorch-0.0.32/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.0.35/titans_pytorch/mac_transformer.py
@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
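The guard exists because slicing with a zero negative stop drops everything: when the sequence length is already a multiple of segment_len, padding is 0 and the old unconditional un-padding in inverse would return an empty tensor. A minimal standalone sketch of the failure mode (plain torch, illustrative shapes, not package code):

    import torch

    segment_len = 64
    seq = torch.randn(2, 128, 384)            # length is already a multiple of 64

    padding = (-seq.shape[1]) % segment_len   # -> 0, nothing to pad

    # old behaviour: unconditional un-padding silently empties the sequence
    print(seq[:, :-padding].shape)            # torch.Size([2, 0, 384])

    # patched behaviour: only un-pad when padding was actually applied
    needs_pad = padding > 0
    out = seq[:, :-padding] if needs_pad else seq
    print(out.shape)                          # torch.Size([2, 128, 384])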
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
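When return_loss is set, the forward pass shifts the sequence by one position so that each token is trained to predict its successor. A tiny illustration of that shift:

    import torch

    x = torch.tensor([[5, 9, 2, 7]])
    inp, labels = x[:, :-1], x[:, 1:]
    # inp    -> tensor([[5, 9, 2]])
    # labels -> tensor([[9, 2, 7]])   each position predicts the next token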
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
 
             if exists(maybe_neural_mem):
                 batch_streams = x.shape[0]
+
                 x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
                 longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
                 x = inverse_segment(x)
 
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
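Taken together with the signature change above, the transformer now either returns next-token logits or computes the language-modeling loss itself. A usage sketch, reusing the constructor configuration that train_mac.py (added below in this release) passes in; shape comments are the expected ones given num_tokens = 256:

    import torch
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    model = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 384,
        depth = 8,
        segment_len = 64,
        num_persist_mem_tokens = 16,
        num_longterm_mem_tokens = 16,
        neural_memory_layers = (3, 4),
        neural_memory_kwargs = dict(
            default_mlp_kwargs = dict(depth = 2)
        )
    )

    ids = torch.randint(0, 256, (1, 512))

    logits = model(ids)                       # (1, 512, 256) next-token logits
    loss = model(ids, return_loss = True)     # scalar cross-entropy over the shifted sequence
    loss.backward()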
--- titans_pytorch-0.0.32/train.py
+++ titans_pytorch-0.0.35/train.py
@@ -63,11 +63,8 @@ def decode_tokens(tokens):
 titans_neural_memory = NeuralMemory(
     dim = 384,
     chunk_size = 4,
-    pre_rmsnorm = True,
-    post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
-    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
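After this change the training script no longer passes the rms-norm and gradient-clipping options explicitly and relies on NeuralMemory's own defaults for them. The resulting instantiation reads as follows (a sketch; the top-level NeuralMemory import is assumed from the rest of the script):

    from titans_pytorch import NeuralMemory

    NEURAL_MEMORY_DEPTH = 2

    titans_neural_memory = NeuralMemory(
        dim = 384,
        chunk_size = 4,
        dim_head = 64,
        heads = 4,
        use_accelerated_scan = True,
        default_mlp_kwargs = dict(
            depth = NEURAL_MEMORY_DEPTH
        )
    )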
--- /dev/null
+++ titans_pytorch-0.0.35/train_mac.py
@@ -0,0 +1,129 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+PROJECT_NAME = 'titans-mac-transformer'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+GLOBAL_LAYERS = (2, 4)
+NEURAL_MEMORY_DEPTH = 2
+WINDOW_SIZE = 64
+RUN_NAME = 'mac'
+
+# wandb experiment tracker
+
+import wandb
+wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+wandb.run.name = RUN_NAME
+wandb.run.save()
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate memory-as-context transformer
+
+model = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    segment_len = WINDOW_SIZE,
+    num_persist_mem_tokens = 16,
+    num_longterm_mem_tokens = 16,
+    neural_memory_layers = (3, 4),
+    neural_memory_kwargs = dict(
+        default_mlp_kwargs = dict(
+            depth = NEURAL_MEMORY_DEPTH
+        )
+    )
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+    wandb.log(dict(loss = loss.item()))
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
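The script ships with SHOULD_GENERATE = False, so the generate call at the bottom is dormant in this release. If one wanted to sample using only the forward pass added above, a minimal greedy decoder suffices; greedy_decode below is a hypothetical helper, not part of the package, and it handles prompt lengths that are not a multiple of the segment window thanks to the needs_pad fix in pad_and_segment_with_inverse:

    import torch

    @torch.no_grad()
    def greedy_decode(model, prompt, length):
        # prompt: (batch, seq) of token ids; appends `length` greedily chosen tokens
        model.eval()
        out = prompt
        for _ in range(length):
            logits = model(out)                                   # (batch, seq, num_tokens)
            next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
            out = torch.cat((out, next_token), dim = -1)
        return out[:, prompt.shape[-1]:]

    # e.g., with the script's own objects:
    # inp = random.choice(val_dataset)[:-1]
    # sampled = greedy_decode(model, inp[None, :], GENERATE_LENGTH)
    # print(decode_tokens(sampled[0].tolist()))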