titans-pytorch 0.0.32__tar.gz → 0.0.35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/PKG-INFO +1 -1
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/pyproject.toml +1 -1
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/__init__.py +2 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/mac_transformer.py +25 -4
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/train.py +0 -3
- titans_pytorch-0.0.35/train_mac.py +129 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/.gitignore +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/LICENSE +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/README.md +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/README.md +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig1.png +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/fig2.png +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/requirements.txt +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.32 → titans_pytorch-0.0.35}/titans_pytorch/titans_attn_memory.py +0 -0
titans_pytorch/mac_transformer.py

@@ -50,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
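Note on the hunk above: padding is now applied only when the sequence length is not already a multiple of segment_len, and the inverse only slices when padding was actually added (out[:, :-padding] with padding == 0 would return an empty tensor). Below is a minimal standalone sketch of the patched behavior, not the library code itself; round_up_multiple is reimplemented here for self-containment, and einops/F.pad are used as in the hunk.

import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):
    # smallest multiple of `mult` greater than or equal to n
    return ((n + mult - 1) // mult) * mult

def pad_and_segment_with_inverse(seq, segment_len):
    # sketch of the patched helper: pad only when the length is not already a multiple
    batch, seq_len = seq.shape[:2]
    padding = round_up_multiple(seq_len, segment_len) - seq_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))  # right-pad the sequence (second-to-last) dimension

    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    def inverse(out):
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
        if needs_pad:
            out = out[:, :-padding]  # out[:, :-0] would be empty, hence the guard
        return out

    return seq, inverse

# a length that is already a multiple of segment_len now round-trips unchanged
x = torch.randn(2, 128, 16)
segments, inverse = pad_and_segment_with_inverse(x, 64)
assert inverse(segments).shape == x.shape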
@@ -226,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
@@ -262,6 +276,7 @@ class MemoryAsContextTransformer(Module):
 
             if exists(maybe_neural_mem):
                 batch_streams = x.shape[0]
+
                 x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
                 longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
@@ -277,6 +292,7 @@ class MemoryAsContextTransformer(Module):
                 x = inverse_segment(x)
 
             x = attn(x)
+
            x = ff(x)
 
         x = self.reduce_streams(x)
@@ -293,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
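Taken together, the forward changes above make the transformer directly usable as a language model: called plainly it returns per-token logits, and with return_loss = True it shifts the input by one position and returns a scalar cross-entropy loss. A usage sketch, reusing the constructor arguments from the new train_mac.py later in this diff; shapes assume one 512-token sequence, and running without .cuda() is an assumption of this sketch, not taken from the diff.

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# constructor arguments mirror train_mac.py in this diff
model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    segment_len = 64,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,
    neural_memory_layers = (3, 4),
    neural_memory_kwargs = dict(
        default_mlp_kwargs = dict(
            depth = 2
        )
    )
)

ids = torch.randint(0, 256, (1, 512))

logits = model(ids)                    # (1, 512, 256) unnormalized token logits
loss = model(ids, return_loss = True)  # scalar cross entropy against the shifted targets
loss.backward()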
train.py

@@ -63,11 +63,8 @@ def decode_tokens(tokens):
 titans_neural_memory = NeuralMemory(
     dim = 384,
     chunk_size = 4,
-    pre_rmsnorm = True,
-    post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
-    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
titans_pytorch-0.0.35/train_mac.py (new file)

@@ -0,0 +1,129 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+PROJECT_NAME = 'titans-mac-transformer'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+GLOBAL_LAYERS = (2, 4)
+NEURAL_MEMORY_DEPTH = 2
+WINDOW_SIZE = 64
+RUN_NAME = 'mac'
+
+# wandb experiment tracker
+
+import wandb
+wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+wandb.run.name = RUN_NAME
+wandb.run.save()
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate memory-as-context transformer
+
+model = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    segment_len = WINDOW_SIZE,
+    num_persist_mem_tokens = 16,
+    num_longterm_mem_tokens = 16,
+    neural_memory_layers = (3, 4),
+    neural_memory_kwargs = dict(
+        default_mlp_kwargs = dict(
+            depth = NEURAL_MEMORY_DEPTH
+        )
+    )
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+    wandb.log(dict(loss = loss.item()))
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)