titans-pytorch 0.0.31__tar.gz → 0.0.34__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/PKG-INFO +1 -1
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/pyproject.toml +1 -1
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/mac_transformer.py +70 -9
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/train.py +0 -3
- titans_pytorch-0.0.34/train_mac.py +129 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/.gitignore +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/LICENSE +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/README.md +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/data/README.md +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/fig1.png +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/fig2.png +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/requirements.txt +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.31 → titans_pytorch-0.0.34}/titans_pytorch/titans_attn_memory.py +0 -0
```diff
--- titans_pytorch-0.0.31/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.0.34/titans_pytorch/mac_transformer.py
@@ -17,6 +17,10 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
 
+# proposed neural memory
+
+from titans_pytorch.titans import NeuralMemory
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
```
```diff
@@ -46,13 +50,20 @@ def pad_and_segment_with_inverse(seq, segment_len):
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
     padding = next_seq_len_mult - seq_len
-    seq = F.pad(seq, (0, 0, 0, padding))
+    needs_pad = padding > 0
+
+    if needs_pad:
+        seq = F.pad(seq, (0, 0, 0, padding))
 
     seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
         out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
-        return out[:, :-padding]
+
+        if needs_pad:
+            out = out[:, :-padding]
+
+        return out
 
     return seq, inverse
 
```
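The guard matters because a Python slice like `out[:, :-0]` keeps nothing rather than everything, so stripping a zero-length padding would empty the sequence; the new `needs_pad` flag skips both the pad and the strip when the length is already a multiple of `segment_len`. A minimal round-trip sketch, assuming titans-pytorch 0.0.34 and its dependencies are importable (tensor sizes here are illustrative):

```python
# illustrative only: exercise pad_and_segment_with_inverse from titans_pytorch 0.0.34
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

seq = torch.randn(2, 100, 384)                  # 100 is not a multiple of the segment length
segments, inverse = pad_and_segment_with_inverse(seq, 64)
print(segments.shape)                           # (4, 64, 384): padded to 128, then windowed
print(inverse(segments).shape)                  # (2, 100, 384): padding stripped again

seq = torch.randn(2, 128, 384)                  # already a multiple: needs_pad is False
segments, inverse = pad_and_segment_with_inverse(seq, 64)
print(inverse(segments).shape)                  # (2, 128, 384): nothing is padded or stripped
```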
```diff
@@ -161,7 +172,9 @@ class MemoryAsContextTransformer(Module):
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        num_residual_streams = 4
+        num_residual_streams = 4,
+        neural_memory_kwargs: dict = dict(),
+        neural_memory_layers: tuple[int, ...] | None = None,
     ):
         super().__init__()
 
```
```diff
@@ -181,8 +194,25 @@ class MemoryAsContextTransformer(Module):
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
+        self.neural_mem_layers = ModuleList([])
+
+        layers = tuple(range(1, depth + 1))
+        neural_memory_layers = set(default(neural_memory_layers, layers))
+
+        for layer in layers:
+
+            # neural memory
+
+            mem = None
+
+            if num_longterm_mem_tokens > 0 and layer in neural_memory_layers:
+                mem = NeuralMemory(dim = dim, chunk_size = num_longterm_mem_tokens)
+                mem = init_hyper_conn(dim = dim, branch = mem)
+
+            self.neural_mem_layers.append(mem)
+
+            # attention and feedforward
 
-        for _ in range(depth):
             attn = SegmentedAttention(
                 dim = dim,
                 dim_head = dim_head,
```
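In the updated constructor, layers are numbered from 1; leaving `neural_memory_layers` as `None` selects every layer (via `default(...)`), a tuple restricts memory to just those layers, and no `NeuralMemory` is built at all when `num_longterm_mem_tokens` is 0. A hedged usage sketch; the argument values are illustrative and mirror the `train_mac.py` script added later in this diff:

```python
# illustrative instantiation of the new 0.0.34 constructor arguments
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 4,
    segment_len = 64,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,     # must be > 0 for any NeuralMemory block to be created
    neural_memory_layers = (2, 4),   # 1-indexed layers; None (the default) selects every layer
    neural_memory_kwargs = dict()    # additional NeuralMemory options (see train_mac.py below)
)
```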
```diff
@@ -203,7 +233,14 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        return_loss = False
+    ):
+
+        if return_loss:
+            x, labels = x[:, :-1], x[:, 1:]
 
         # math
 
```
```diff
@@ -221,7 +258,7 @@ class MemoryAsContextTransformer(Module):
         x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
 
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
-        x = …
+        x = cat((mems, x), dim = -2)
 
         x = inverse_segment(x)
 
```
```diff
@@ -235,8 +272,27 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):
+
+            if exists(maybe_neural_mem):
+                batch_streams = x.shape[0]
+
+                x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
+
+                longterm_mems, x = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]
+
+                longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch_streams)
+
+                longterm_mems = maybe_neural_mem(longterm_mems)
+
+                longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
+
+                x = cat((longterm_mems, x), dim = -2)
+
+                x = inverse_segment(x)
+
             x = attn(x)
+
             x = ff(x)
 
         x = self.reduce_streams(x)
```
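The bookkeeping in this new block: the sequence arrives windowed as `(batch * windows, segment_len + num_longterm_mem_tokens, dim)`, the per-window memory tokens are gathered into one contiguous sequence per batch element, run through the layer's `NeuralMemory`, then scattered back in front of their windows. A toy sketch of just the shape handling (the memory module is replaced by an identity, and the leading dimension ignores the extra hyper-connection streams the real forward pass would carry):

```python
# toy shape walk-through of the per-layer memory splice above (identity in place of NeuralMemory)
import torch
from einops import rearrange

batch, windows, dim = 2, 3, 8
segment_len, num_longterm_mem_tokens = 4, 2
total_segment_len = segment_len + num_longterm_mem_tokens

x = torch.randn(batch * windows, total_segment_len, dim)   # as produced by pad_and_segment_with_inverse

longterm_mems, rest = x[:, :num_longterm_mem_tokens], x[:, num_longterm_mem_tokens:]

longterm_mems = rearrange(longterm_mems, '(b w) n d -> b (w n) d', b = batch)
print(longterm_mems.shape)                                  # torch.Size([2, 6, 8]): what the memory module would see

longterm_mems = rearrange(longterm_mems, 'b (w n) d -> (b w) n d', n = num_longterm_mem_tokens)
x = torch.cat((longterm_mems, rest), dim = -2)
print(x.shape)                                              # torch.Size([6, 6, 8]): windows reassembled
```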
```diff
@@ -245,7 +301,7 @@ class MemoryAsContextTransformer(Module):
 
         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-        x = x[:, …
+        x = x[:, num_longterm_mem_tokens:]
 
         x = inverse_segment(x)
 
```
```diff
@@ -253,4 +309,9 @@ class MemoryAsContextTransformer(Module):
 
         x = self.norm(x)
 
-        return self.to_logits(x)
+        logits = self.to_logits(x)
+
+        if not return_loss:
+            return logits
+
+        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
```
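With this change the forward pass has two modes: called plainly it returns per-position logits, and with `return_loss = True` it shifts the input by one token internally and returns the cross-entropy scalar. A short sketch, reusing the illustrative `model` from the constructor sketch above:

```python
# sketch of the two forward modes added in 0.0.34 (model from the earlier illustrative instantiation)
import torch

ids = torch.randint(0, 256, (1, 128))

logits = model(ids)                       # expected shape: (1, 128, 256)

loss = model(ids, return_loss = True)     # internally uses ids[:, :-1] as input and ids[:, 1:] as labels
loss.backward()
```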
```diff
--- titans_pytorch-0.0.31/train.py
+++ titans_pytorch-0.0.34/train.py
@@ -63,11 +63,8 @@ def decode_tokens(tokens):
 titans_neural_memory = NeuralMemory(
     dim = 384,
     chunk_size = 4,
-    pre_rmsnorm = True,
-    post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
-    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
```
```diff
--- /dev/null
+++ titans_pytorch-0.0.34/train_mac.py
@@ -0,0 +1,129 @@
+import random
+import tqdm
+import gzip
+import numpy as np
+
+import torch
+from torch import nn
+from torch.optim import Adam
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+# constants
+
+NUM_BATCHES = int(1e5)
+BATCH_SIZE = 4
+GRADIENT_ACCUMULATE_EVERY = 4
+LEARNING_RATE = 2e-4
+VALIDATE_EVERY = 100
+GENERATE_EVERY = 500
+GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
+SEQ_LEN = 512
+
+PROJECT_NAME = 'titans-mac-transformer'
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+GLOBAL_LAYERS = (2, 4)
+NEURAL_MEMORY_DEPTH = 2
+WINDOW_SIZE = 64
+RUN_NAME = 'mac'
+
+# wandb experiment tracker
+
+import wandb
+wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+wandb.run.name = RUN_NAME
+wandb.run.save()
+
+# helpers
+
+def cycle(loader):
+    while True:
+        for data in loader:
+            yield data
+
+def decode_token(token):
+    return str(chr(max(32, token)))
+
+def decode_tokens(tokens):
+    return ''.join(list(map(decode_token, tokens)))
+
+# instantiate memory-as-context transformer
+
+model = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 384,
+    depth = 8,
+    segment_len = WINDOW_SIZE,
+    num_persist_mem_tokens = 16,
+    num_longterm_mem_tokens = 16,
+    neural_memory_layers = (3, 4),
+    neural_memory_kwargs = dict(
+        default_mlp_kwargs = dict(
+            depth = NEURAL_MEMORY_DEPTH
+        )
+    )
+).cuda()
+
+# prepare enwik8 data
+
+with gzip.open('./data/enwik8.gz') as file:
+    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+    data_train, data_val = np.split(data, [int(90e6)])
+    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+class TextSamplerDataset(Dataset):
+    def __init__(self, data, seq_len):
+        super().__init__()
+        self.data = data
+        self.seq_len = seq_len
+
+    def __getitem__(self, index):
+        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+        return full_seq.cuda()
+
+    def __len__(self):
+        return self.data.size(0) // self.seq_len
+
+train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+# optimizer
+
+optim = Adam(model.parameters(), lr=LEARNING_RATE)
+
+# training
+
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+    model.train()
+
+    for __ in range(GRADIENT_ACCUMULATE_EVERY):
+        loss = model(next(train_loader), return_loss = True)
+        loss.backward()
+
+    print(f'training loss: {loss.item()}')
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+    optim.step()
+    optim.zero_grad()
+    wandb.log(dict(loss = loss.item()))
+
+    if i % VALIDATE_EVERY == 0:
+        model.eval()
+        with torch.no_grad():
+            loss = model(next(val_loader), return_loss = True)
+            print(f'validation loss: {loss.item()}')
+
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+        model.eval()
+        inp = random.choice(val_dataset)[:-1]
+        prime = decode_tokens(inp)
+        print(f'%s \n\n %s', (prime, '*' * 100))
+
+        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+        output_str = decode_tokens(sample[0])
+        print(output_str)
```