titans-pytorch 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/PKG-INFO +1 -1
- titans_pytorch-0.0.7/data/README.md +3 -0
- titans_pytorch-0.0.7/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/pyproject.toml +1 -1
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/titans_pytorch/titans.py +9 -1
- titans_pytorch-0.0.7/train.py +108 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/.gitignore +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/LICENSE +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/README.md +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/fig1.png +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/fig2.png +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.6 → titans_pytorch-0.0.7}/titans_pytorch/associative_scan.py +0 -0
|
Binary file
|
|
@@ -91,10 +91,14 @@ class NeuralMemory(Module):
|
|
|
91
91
|
dim,
|
|
92
92
|
chunk_size = 1,
|
|
93
93
|
model: Module | None = None,
|
|
94
|
-
store_memory_loss_fn: Callable = default_loss_fn
|
|
94
|
+
store_memory_loss_fn: Callable = default_loss_fn,
|
|
95
|
+
pre_rmsnorm = False
|
|
95
96
|
):
|
|
96
97
|
super().__init__()
|
|
97
98
|
|
|
99
|
+
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
100
|
+
self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
|
|
101
|
+
|
|
98
102
|
if not exists(model):
|
|
99
103
|
model = MLP(dim, depth = 4)
|
|
100
104
|
|
|
@@ -161,6 +165,8 @@ class NeuralMemory(Module):
|
|
|
161
165
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
|
|
162
166
|
):
|
|
163
167
|
|
|
168
|
+
seq = self.store_norm(seq)
|
|
169
|
+
|
|
164
170
|
# curtail sequence by multiple of the chunk size
|
|
165
171
|
# only a complete chunk of the sequence provides the memory for the next chunk
|
|
166
172
|
|
|
@@ -244,6 +250,8 @@ class NeuralMemory(Module):
|
|
|
244
250
|
chunk_size = self.chunk_size
|
|
245
251
|
batch, seq_len = seq.shape[:2]
|
|
246
252
|
|
|
253
|
+
seq = self.retrieve_norm(seq)
|
|
254
|
+
|
|
247
255
|
assert seq_len >= chunk_size
|
|
248
256
|
|
|
249
257
|
seq = seq[:, (chunk_size - 1):]
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import tqdm
|
|
3
|
+
import gzip
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch.optim import Adam
|
|
8
|
+
from torch.nn import functional as F
|
|
9
|
+
from torch.utils.data import DataLoader, Dataset
|
|
10
|
+
|
|
11
|
+
from local_attention import LocalTransformer
|
|
12
|
+
|
|
13
|
+
from titans_pytorch.titans import NeuralMemory
|
|
14
|
+
|
|
15
|
+
# constants (training hyperparameters)

NUM_BATCHES = 100_000               # optimizer steps (outer training-loop iterations)
BATCH_SIZE = 4                      # sequences per micro-batch (DataLoader batch_size)
GRADIENT_ACCUMULATE_EVERY = 4       # micro-batches whose gradients feed each optimizer step
LEARNING_RATE = 2e-4                # Adam learning rate
VALIDATE_EVERY = 100                # steps between validation-loss prints
GENERATE_EVERY = 500                # steps between sampled-text generations
GENERATE_LENGTH = 512               # length argument passed to model.generate
SEQ_LEN = 512                       # context length; the dataset slices SEQ_LEN + 1 tokens per sample
|
|
25
|
+
|
|
26
|
+
# helpers
|
|
27
|
+
|
|
28
|
+
def cycle(loader):
    """Yield items from ``loader`` endlessly, re-iterating it after each pass.

    ``loader`` must be re-iterable (e.g. a ``DataLoader``): every pass starts a
    fresh iteration over it. If a pass yields nothing -- an empty iterable, or a
    one-shot iterator that is already exhausted -- a ``ValueError`` is raised
    instead of spinning forever in a loop that never yields.

    Raises:
        ValueError: if a full pass over ``loader`` produces no items.
    """
    while True:
        yielded_any = False
        for data in loader:
            yielded_any = True
            yield data
        # Without this guard an empty pass would make `while True` busy-loop forever.
        if not yielded_any:
            raise ValueError('cycle() got an empty or already-exhausted iterable')
|
|
32
|
+
|
|
33
|
+
def decode_token(token):
    """Turn one token id into a printable character.

    Ids below 32 (the ASCII control range) are shown as a space so sampled
    text prints cleanly; every other id is passed to ``chr`` as-is.
    """
    return chr(token if token > 32 else 32)

def decode_tokens(tokens):
    """Concatenate ``decode_token`` over an iterable of token ids into one string."""
    return ''.join(decode_token(tok) for tok in tokens)
|
|
38
|
+
|
|
39
|
+
# instantiate a GPT-like causal decoder over 256 token ids (one per byte value of
# the uint8 enwik8 data), built with local attention, then move it to the GPU

model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    causal = True,
    local_attn_window_size = 64,
    max_seq_len = SEQ_LEN,
)
model = model.cuda()
|
|
49
|
+
|
|
50
|
+
# prepare enwik8 data: read up to the first 95e6 bytes, use everything before
# byte 90e6 for training and the remainder for validation

with gzip.open('./data/enwik8.gz') as file:
    # np.frombuffer over a bytes object is read-only; .copy() gives a writable array for torch
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()

data_train, data_val = data[:int(90e6)], data[int(90e6):]

data_train = torch.from_numpy(data_train)
data_val = torch.from_numpy(data_val)
|
|
56
|
+
|
|
57
|
+
class TextSamplerDataset(Dataset):
    """Random fixed-length windows over a 1-D token tensor.

    ``__getitem__`` ignores its ``index`` argument: every access draws a fresh
    random start offset, so the dataset acts as a sampler whose nominal length
    is ``data.size(0) // seq_len``.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        # seq_len + 1 tokens per window -- presumably so the model (called with
        # return_loss = True) can split it into inputs and next-token targets
        end = start + self.seq_len + 1
        window = self.data[start:end].long()
        return window.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len
|
|
70
|
+
|
|
71
|
+
# wrap each split in a sampler dataset, then in an endlessly repeating loader

train_dataset, val_dataset = (
    TextSamplerDataset(split, SEQ_LEN) for split in (data_train, data_val)
)

train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
|
|
75
|
+
|
|
76
|
+
# optimizer: Adam over every model parameter

optim = Adam(params = model.parameters(), lr = LEARNING_RATE)
|
|
79
|
+
|
|
80
|
+
# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
    model.train()

    # accumulate gradients over several micro-batches; each loss is divided by the
    # number of micro-batches so the summed gradient is that of the mean loss and
    # the clip threshold below does not depend on GRADIENT_ACCUMULATE_EVERY
    accumulated_loss = 0.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss = loss / GRADIENT_ACCUMULATE_EVERY
        loss.backward()
        # detach so logging neither keeps the graph alive nor syncs per micro-batch
        accumulated_loss += loss.detach()

    # mean loss over all accumulated micro-batches, not just the last one
    print(f'training loss: {float(accumulated_loss)}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # drop the final token so the prime is exactly SEQ_LEN tokens long
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # previously mixed an f-string with %-placeholders and printed them literally
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
        output_str = decode_tokens(sample[0])
        print(output_str)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|